Building your diffusion posterior sampling method using SDEs#

This demo shows you how to use deepinv.sampling.PosteriorDiffusion to perform posterior sampling. It also can be used to perform unconditional image generation with arbitrary denoisers, if the data fidelity term is not specified.

This method requires:

The deepinv.sampling.PosteriorDiffusion class can be used to perform posterior sampling for inverse problems. Consider the acquisition model:

\[y = \noise{\forw{x}}\]

where \(\forw{x}\) is the forward operator (e.g., a convolutional operator) and \(\noise{\cdot}\) is the noise operator (e.g., Gaussian noise). This class defines the reverse-time SDE for the posterior distribution \(p(x|y)\) given the data \(y\):

\[d\, x_t = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla_{x_t} \log p_t(x_t | y) \right) d\,t + g(t) \sqrt{\alpha} d\, w_{t}\]

where \(f\) is the drift term, \(g\) is the diffusion coefficient and \(w\) is the standard Brownian motion. The drift term and the diffusion coefficient are defined by the underlying (unconditional) forward-time SDE sde. In this example, we will use 2 well-known SDE in the literature: the Variance-Exploding (VE) and Variance-Preserving (VP or DDPM).

The (conditional) score function \(\nabla_{x_t} \log p_t(x_t | y)\) can be decomposed using the Bayes’ rule:

\[\nabla_{x_t} \log p_t(x_t | y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y | x_t).\]

The first term is the score function of the unconditional SDE, which is typically approximated by an MMSE denoiser (denoiser) using the well-known Tweedie’s formula, while the second term is approximated by the (noisy) data-fidelity term (data_fidelity). We implement various data-fidelity terms in the user guide.

Note

In this demo, we limit the number of diffusion steps for the sake of speed, but in practice, you should use a larger number of steps to obtain better results.


Let us import the necessary modules, define the denoiser and the SDE.

In this first example, we use the Variance-Exploding SDE, whose forward process is defined as:

\[d\, x_t = g(t) d\, w_t \quad \mbox{where } g(t) = \sigma_{\mathrm{min}}\left( \frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}}\right)^t\sqrt{2\log\frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}} }.\]
import torch
import matplotlib as mpl
import deepinv as dinv
from deepinv.models import NCSNpp

device = dinv.utils.get_device()
dtype = torch.float64
dtype = torch.float32
figsize = 2.5
gif_frequency = 10  # Increase this value to save the GIF saving time
mpl.rcParams["animation.html"] = "jshtml"
Selected CPU device
from deepinv.sampling import (
    PosteriorDiffusion,
    DPSDataFidelity,
    EulerSolver,
    VarianceExplodingDiffusion,
    VariancePreservingDiffusion,
)
from deepinv.optim import ZeroFidelity

# In this example, we use the pre-trained FFHQ-64 model from the
# EDM framework: https://arxiv.org/pdf/2206.00364 .
# The network architecture is from Song et al: https://arxiv.org/abs/2011.13456 .
denoiser = NCSNpp(pretrained="download").to(device)


# The solution is obtained by calling the SDE object with a desired solver (here, Euler).
# The reproducibility of the SDE Solver class can be controlled by providing the pseudo-random number generator.
num_steps = 150
rng = torch.Generator(device).manual_seed(42)
timesteps = torch.linspace(1, 0.001, num_steps)
solver = EulerSolver(timesteps=timesteps, rng=rng)
sde = VarianceExplodingDiffusion(
    device=device,
    dtype=dtype,
)

Reverse-time SDE as sampling process#

When the data fidelity is not given, the posterior diffusion is equivalent to the unconditional diffusion. Sampling is performed by solving the reverse-time SDE. To do so, we generate a reverse-time trajectory.

model = PosteriorDiffusion(
    data_fidelity=ZeroFidelity(),
    sde=sde,
    denoiser=denoiser,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)
x, trajectory = model(
    y=None,
    physics=None,
    x_init=(1, 3, 64, 64),
    seed=10,
    get_trajectory=True,
    denoise_output=True,  # We set this to True to perform an additional denoising step at the end of the sampling process, which can improve the sample quality when the diffusion term is large at the end of the sampling process.
)

dinv.utils.plot(
    x,
    titles="Unconditional generation",
    save_fn="sde_sample.png",
    figsize=(figsize, figsize),
)
Unconditional generation
  0%|          | 0/149 [00:00<?, ?it/s]
  1%|          | 1/149 [00:00<00:26,  5.55it/s]
  1%|▏         | 2/149 [00:00<00:20,  7.24it/s]
  2%|▏         | 3/149 [00:00<00:18,  7.94it/s]
  3%|β–Ž         | 4/149 [00:00<00:18,  7.64it/s]
  3%|β–Ž         | 5/149 [00:00<00:18,  7.60it/s]
  4%|▍         | 6/149 [00:00<00:18,  7.71it/s]
  5%|▍         | 7/149 [00:00<00:17,  8.15it/s]
  5%|β–Œ         | 8/149 [00:01<00:16,  8.46it/s]
  6%|β–Œ         | 9/149 [00:01<00:16,  8.69it/s]
  7%|β–‹         | 10/149 [00:01<00:15,  8.84it/s]
  7%|β–‹         | 11/149 [00:01<00:15,  8.95it/s]
  8%|β–Š         | 12/149 [00:01<00:15,  9.03it/s]
  9%|β–Š         | 13/149 [00:01<00:14,  9.09it/s]
  9%|β–‰         | 14/149 [00:01<00:14,  9.06it/s]
 10%|β–ˆ         | 15/149 [00:01<00:14,  9.09it/s]
 11%|β–ˆ         | 16/149 [00:01<00:14,  9.12it/s]
 11%|β–ˆβ–        | 17/149 [00:01<00:14,  9.14it/s]
 12%|β–ˆβ–        | 18/149 [00:02<00:14,  9.16it/s]
 13%|β–ˆβ–Ž        | 19/149 [00:02<00:14,  9.18it/s]
 13%|β–ˆβ–Ž        | 20/149 [00:02<00:14,  9.19it/s]
 14%|β–ˆβ–        | 21/149 [00:02<00:13,  9.20it/s]
 15%|β–ˆβ–        | 22/149 [00:02<00:13,  9.21it/s]
 15%|β–ˆβ–Œ        | 23/149 [00:02<00:13,  9.22it/s]
 16%|β–ˆβ–Œ        | 24/149 [00:02<00:13,  9.21it/s]
 17%|β–ˆβ–‹        | 25/149 [00:02<00:14,  8.61it/s]
 17%|β–ˆβ–‹        | 26/149 [00:03<00:15,  8.04it/s]
 18%|β–ˆβ–Š        | 27/149 [00:03<00:15,  7.88it/s]
 19%|β–ˆβ–‰        | 28/149 [00:03<00:14,  8.24it/s]
 19%|β–ˆβ–‰        | 29/149 [00:03<00:14,  8.51it/s]
 20%|β–ˆβ–ˆ        | 30/149 [00:03<00:13,  8.70it/s]
 21%|β–ˆβ–ˆ        | 31/149 [00:03<00:13,  8.85it/s]
 21%|β–ˆβ–ˆβ–       | 32/149 [00:03<00:13,  8.96it/s]
 22%|β–ˆβ–ˆβ–       | 33/149 [00:03<00:12,  9.05it/s]
 23%|β–ˆβ–ˆβ–Ž       | 34/149 [00:03<00:12,  9.10it/s]
 23%|β–ˆβ–ˆβ–Ž       | 35/149 [00:04<00:12,  9.14it/s]
 24%|β–ˆβ–ˆβ–       | 36/149 [00:04<00:12,  9.17it/s]
 25%|β–ˆβ–ˆβ–       | 37/149 [00:04<00:12,  9.18it/s]
 26%|β–ˆβ–ˆβ–Œ       | 38/149 [00:04<00:12,  9.19it/s]
 26%|β–ˆβ–ˆβ–Œ       | 39/149 [00:04<00:11,  9.21it/s]
 27%|β–ˆβ–ˆβ–‹       | 40/149 [00:04<00:11,  9.22it/s]
 28%|β–ˆβ–ˆβ–Š       | 41/149 [00:04<00:11,  9.22it/s]
 28%|β–ˆβ–ˆβ–Š       | 42/149 [00:04<00:11,  9.21it/s]
 29%|β–ˆβ–ˆβ–‰       | 43/149 [00:04<00:11,  9.22it/s]
 30%|β–ˆβ–ˆβ–‰       | 44/149 [00:05<00:11,  9.22it/s]
 30%|β–ˆβ–ˆβ–ˆ       | 45/149 [00:05<00:11,  9.22it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 46/149 [00:05<00:11,  8.71it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 47/149 [00:05<00:12,  8.31it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 48/149 [00:05<00:12,  7.94it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 49/149 [00:05<00:12,  8.27it/s]
 34%|β–ˆβ–ˆβ–ˆβ–Ž      | 50/149 [00:05<00:11,  8.54it/s]
 34%|β–ˆβ–ˆβ–ˆβ–      | 51/149 [00:05<00:11,  8.73it/s]
 35%|β–ˆβ–ˆβ–ˆβ–      | 52/149 [00:05<00:10,  8.87it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 53/149 [00:06<00:10,  8.97it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 54/149 [00:06<00:10,  9.05it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 55/149 [00:06<00:10,  9.10it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 56/149 [00:06<00:10,  9.14it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 57/149 [00:06<00:10,  9.16it/s]
 39%|β–ˆβ–ˆβ–ˆβ–‰      | 58/149 [00:06<00:09,  9.17it/s]
 40%|β–ˆβ–ˆβ–ˆβ–‰      | 59/149 [00:06<00:09,  9.19it/s]
 40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 60/149 [00:06<00:09,  9.18it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆ      | 61/149 [00:06<00:09,  9.19it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 62/149 [00:07<00:09,  9.20it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 63/149 [00:07<00:09,  9.21it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 64/149 [00:07<00:09,  9.21it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 65/149 [00:07<00:09,  9.21it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 66/149 [00:07<00:09,  9.21it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 67/149 [00:07<00:09,  8.74it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 68/149 [00:07<00:09,  8.39it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 69/149 [00:07<00:10,  7.96it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 70/149 [00:07<00:09,  8.17it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 71/149 [00:08<00:09,  8.46it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 72/149 [00:08<00:08,  8.67it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 73/149 [00:08<00:08,  8.83it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 74/149 [00:08<00:08,  8.95it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 75/149 [00:08<00:08,  9.03it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 76/149 [00:08<00:08,  9.09it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 77/149 [00:08<00:07,  9.13it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 78/149 [00:08<00:07,  9.15it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 79/149 [00:08<00:07,  9.17it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 80/149 [00:09<00:07,  9.19it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 81/149 [00:09<00:07,  9.19it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 82/149 [00:09<00:07,  9.20it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 83/149 [00:09<00:07,  9.21it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 84/149 [00:09<00:07,  9.20it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 85/149 [00:09<00:06,  9.21it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 86/149 [00:09<00:06,  9.21it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 87/149 [00:09<00:06,  9.21it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 88/149 [00:09<00:06,  8.89it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 89/149 [00:10<00:07,  8.46it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 90/149 [00:10<00:07,  8.09it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 91/149 [00:10<00:07,  8.18it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 92/149 [00:10<00:06,  8.46it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 93/149 [00:10<00:06,  8.67it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 94/149 [00:10<00:06,  8.82it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 95/149 [00:10<00:06,  8.92it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 96/149 [00:10<00:05,  9.00it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 97/149 [00:10<00:05,  9.06it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 98/149 [00:11<00:05,  9.10it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 99/149 [00:11<00:05,  9.14it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 100/149 [00:11<00:05,  9.16it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 101/149 [00:11<00:05,  9.18it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 102/149 [00:11<00:05,  9.19it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 103/149 [00:11<00:04,  9.20it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 104/149 [00:11<00:04,  9.21it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 105/149 [00:11<00:04,  9.21it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 106/149 [00:11<00:04,  9.21it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 107/149 [00:12<00:04,  9.21it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 108/149 [00:12<00:04,  9.21it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 109/149 [00:12<00:04,  9.11it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 110/149 [00:12<00:04,  8.49it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 111/149 [00:12<00:04,  8.14it/s]
 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 112/149 [00:12<00:04,  8.13it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 113/149 [00:12<00:04,  8.43it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 114/149 [00:12<00:04,  8.65it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 115/149 [00:13<00:03,  8.81it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 116/149 [00:13<00:03,  8.93it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 117/149 [00:13<00:03,  9.02it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 118/149 [00:13<00:03,  9.08it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 119/149 [00:13<00:03,  9.12it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 120/149 [00:13<00:03,  9.15it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 121/149 [00:13<00:03,  9.17it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 122/149 [00:13<00:02,  9.19it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 123/149 [00:13<00:02,  9.20it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 124/149 [00:13<00:02,  9.20it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 125/149 [00:14<00:02,  9.21it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 126/149 [00:14<00:02,  9.22it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 127/149 [00:14<00:02,  9.22it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 128/149 [00:14<00:02,  9.21it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 129/149 [00:14<00:02,  9.22it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 130/149 [00:14<00:02,  9.22it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 131/149 [00:14<00:02,  8.57it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 132/149 [00:14<00:02,  8.19it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 133/149 [00:15<00:01,  8.02it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 134/149 [00:15<00:01,  8.34it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 135/149 [00:15<00:01,  8.58it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 136/149 [00:15<00:01,  8.77it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 137/149 [00:15<00:01,  8.90it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 138/149 [00:15<00:01,  9.00it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 139/149 [00:15<00:01,  9.07it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 140/149 [00:15<00:00,  9.11it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 141/149 [00:15<00:00,  9.13it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 142/149 [00:16<00:00,  9.16it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 143/149 [00:16<00:00,  9.18it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 144/149 [00:16<00:00,  9.19it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 145/149 [00:16<00:00,  9.19it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 146/149 [00:16<00:00,  9.20it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 147/149 [00:16<00:00,  9.20it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 148/149 [00:16<00:00,  9.21it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:16<00:00,  9.21it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:16<00:00,  8.88it/s]

We can also save the trajectory of the sample

anim = dinv.utils.plot_videos(
    trajectory.cpu()[::gif_frequency],
    time_dim=0,
    titles=["VE-SDE Trajectory"],
    figsize=(figsize, figsize),
    return_anim=True,
)
anim


When the data fidelity is given, together with the measurements and the physics, this class can be used to perform posterior sampling for inverse problems. For example, consider the inpainting problem, where we have a noisy image and we want to recover the original image. We can use the deepinv.sampling.DPSDataFidelity as the data fidelity term.

mask = torch.ones_like(x)
mask[..., 24:40, 24:40] = 0.0
physics = dinv.physics.Inpainting(img_size=x.shape[1:], mask=mask, device=device)
y = physics(x)

weight = 4.0  # guidance strength
dps_fidelity = DPSDataFidelity(denoiser=denoiser, weight=weight)

model = PosteriorDiffusion(
    data_fidelity=dps_fidelity,
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)

# To perform posterior sampling, we need to provide the measurements, the physics and the solver.
# Moreover, when the physics is given, the initial point can be inferred from the physics if not given explicitly.

seed_1 = 11

x_hat, trajectory = model(
    y,
    physics,
    seed=seed_1,
    get_trajectory=True,
    denoise_output=True,  # We set this to True to perform an additional denoising step at the end
)

# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
    [x, y, x_hat],
    show=True,
    titles=["Original", "Measurement", "Posterior sample"],
    figsize=(figsize * 3, figsize),
)
Original, Measurement, Posterior sample
  0%|          | 0/149 [00:00<?, ?it/s]
  1%|          | 1/149 [00:00<00:36,  4.04it/s]
  1%|▏         | 2/149 [00:00<00:34,  4.31it/s]
  2%|▏         | 3/149 [00:00<00:33,  4.40it/s]
  3%|β–Ž         | 4/149 [00:00<00:34,  4.20it/s]
  3%|β–Ž         | 5/149 [00:01<00:36,  3.91it/s]
  4%|▍         | 6/149 [00:01<00:34,  4.10it/s]
  5%|▍         | 7/149 [00:01<00:33,  4.23it/s]
  5%|β–Œ         | 8/149 [00:01<00:32,  4.32it/s]
  6%|β–Œ         | 9/149 [00:02<00:31,  4.39it/s]
  7%|β–‹         | 10/149 [00:02<00:31,  4.43it/s]
  7%|β–‹         | 11/149 [00:02<00:30,  4.46it/s]
  8%|β–Š         | 12/149 [00:02<00:30,  4.48it/s]
  9%|β–Š         | 13/149 [00:02<00:30,  4.49it/s]
  9%|β–‰         | 14/149 [00:03<00:30,  4.50it/s]
 10%|β–ˆ         | 15/149 [00:03<00:31,  4.25it/s]
 11%|β–ˆ         | 16/149 [00:03<00:31,  4.18it/s]
 11%|β–ˆβ–        | 17/149 [00:03<00:30,  4.27it/s]
 12%|β–ˆβ–        | 18/149 [00:04<00:30,  4.34it/s]
 13%|β–ˆβ–Ž        | 19/149 [00:04<00:29,  4.39it/s]
 13%|β–ˆβ–Ž        | 20/149 [00:04<00:29,  4.43it/s]
 14%|β–ˆβ–        | 21/149 [00:04<00:28,  4.45it/s]
 15%|β–ˆβ–        | 22/149 [00:05<00:28,  4.45it/s]
 15%|β–ˆβ–Œ        | 23/149 [00:05<00:28,  4.46it/s]
 16%|β–ˆβ–Œ        | 24/149 [00:05<00:27,  4.47it/s]
 17%|β–ˆβ–‹        | 25/149 [00:05<00:28,  4.33it/s]
 17%|β–ˆβ–‹        | 26/149 [00:06<00:29,  4.14it/s]
 18%|β–ˆβ–Š        | 27/149 [00:06<00:28,  4.24it/s]
 19%|β–ˆβ–‰        | 28/149 [00:06<00:28,  4.29it/s]
 19%|β–ˆβ–‰        | 29/149 [00:06<00:27,  4.36it/s]
 20%|β–ˆβ–ˆ        | 30/149 [00:06<00:27,  4.40it/s]
 21%|β–ˆβ–ˆ        | 31/149 [00:07<00:26,  4.44it/s]
 21%|β–ˆβ–ˆβ–       | 32/149 [00:07<00:26,  4.47it/s]
 22%|β–ˆβ–ˆβ–       | 33/149 [00:07<00:25,  4.48it/s]
 23%|β–ˆβ–ˆβ–Ž       | 34/149 [00:07<00:25,  4.50it/s]
 23%|β–ˆβ–ˆβ–Ž       | 35/149 [00:08<00:25,  4.47it/s]
 24%|β–ˆβ–ˆβ–       | 36/149 [00:08<00:26,  4.22it/s]
 25%|β–ˆβ–ˆβ–       | 37/149 [00:08<00:26,  4.24it/s]
 26%|β–ˆβ–ˆβ–Œ       | 38/149 [00:08<00:25,  4.32it/s]
 26%|β–ˆβ–ˆβ–Œ       | 39/149 [00:08<00:25,  4.38it/s]
 27%|β–ˆβ–ˆβ–‹       | 40/149 [00:09<00:24,  4.42it/s]
 28%|β–ˆβ–ˆβ–Š       | 41/149 [00:09<00:24,  4.45it/s]
 28%|β–ˆβ–ˆβ–Š       | 42/149 [00:09<00:23,  4.47it/s]
 29%|β–ˆβ–ˆβ–‰       | 43/149 [00:09<00:23,  4.48it/s]
 30%|β–ˆβ–ˆβ–‰       | 44/149 [00:10<00:23,  4.49it/s]
 30%|β–ˆβ–ˆβ–ˆ       | 45/149 [00:10<00:23,  4.50it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 46/149 [00:10<00:23,  4.30it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 47/149 [00:10<00:24,  4.13it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 48/149 [00:11<00:23,  4.24it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 49/149 [00:11<00:23,  4.33it/s]
 34%|β–ˆβ–ˆβ–ˆβ–Ž      | 50/149 [00:11<00:22,  4.39it/s]
 34%|β–ˆβ–ˆβ–ˆβ–      | 51/149 [00:11<00:22,  4.43it/s]
 35%|β–ˆβ–ˆβ–ˆβ–      | 52/149 [00:11<00:21,  4.46it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 53/149 [00:12<00:21,  4.47it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 54/149 [00:12<00:21,  4.49it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 55/149 [00:12<00:20,  4.50it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 56/149 [00:12<00:20,  4.44it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 57/149 [00:13<00:22,  4.18it/s]
 39%|β–ˆβ–ˆβ–ˆβ–‰      | 58/149 [00:13<00:21,  4.22it/s]
 40%|β–ˆβ–ˆβ–ˆβ–‰      | 59/149 [00:13<00:20,  4.31it/s]
 40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 60/149 [00:13<00:20,  4.37it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆ      | 61/149 [00:13<00:19,  4.41it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 62/149 [00:14<00:19,  4.44it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 63/149 [00:14<00:19,  4.47it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 64/149 [00:14<00:19,  4.47it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 65/149 [00:14<00:18,  4.48it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 66/149 [00:15<00:18,  4.50it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 67/149 [00:15<00:19,  4.26it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 68/149 [00:15<00:19,  4.16it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 69/149 [00:15<00:18,  4.26it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 70/149 [00:16<00:18,  4.25it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 71/149 [00:16<00:18,  4.32it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 72/149 [00:16<00:17,  4.38it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 73/149 [00:16<00:17,  4.43it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 74/149 [00:16<00:16,  4.46it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 75/149 [00:17<00:16,  4.48it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 76/149 [00:17<00:16,  4.49it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 77/149 [00:17<00:16,  4.36it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 78/149 [00:17<00:17,  4.13it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 79/149 [00:18<00:16,  4.23it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 80/149 [00:18<00:15,  4.32it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 81/149 [00:18<00:15,  4.38it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 82/149 [00:18<00:15,  4.42it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 83/149 [00:19<00:14,  4.45it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 84/149 [00:19<00:14,  4.47it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 85/149 [00:19<00:14,  4.49it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 86/149 [00:19<00:14,  4.49it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 87/149 [00:19<00:13,  4.49it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 88/149 [00:20<00:14,  4.24it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 89/149 [00:20<00:14,  4.22it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 90/149 [00:20<00:13,  4.30it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 91/149 [00:20<00:13,  4.36it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 92/149 [00:21<00:12,  4.41it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 93/149 [00:21<00:12,  4.44it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 94/149 [00:21<00:12,  4.46it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 95/149 [00:21<00:12,  4.48it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 96/149 [00:21<00:11,  4.49it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 97/149 [00:22<00:11,  4.50it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 98/149 [00:22<00:11,  4.32it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 99/149 [00:22<00:12,  4.14it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 100/149 [00:22<00:11,  4.24it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 101/149 [00:23<00:11,  4.32it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 102/149 [00:23<00:10,  4.38it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 103/149 [00:23<00:10,  4.42it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 104/149 [00:23<00:10,  4.45it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 105/149 [00:24<00:09,  4.47it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 106/149 [00:24<00:09,  4.48it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 107/149 [00:24<00:09,  4.49it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 108/149 [00:24<00:09,  4.45it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 109/149 [00:24<00:09,  4.18it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 110/149 [00:25<00:09,  4.20it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 111/149 [00:25<00:08,  4.29it/s]
 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 112/149 [00:25<00:08,  4.36it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 113/149 [00:25<00:08,  4.41it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 114/149 [00:26<00:07,  4.44it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 115/149 [00:26<00:07,  4.46it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 116/149 [00:26<00:07,  4.47it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 117/149 [00:26<00:07,  4.49it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 118/149 [00:26<00:06,  4.50it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 119/149 [00:27<00:07,  4.27it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 120/149 [00:27<00:07,  4.14it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 121/149 [00:27<00:06,  4.24it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 122/149 [00:27<00:06,  4.32it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 123/149 [00:28<00:05,  4.37it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 124/149 [00:28<00:05,  4.42it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 125/149 [00:28<00:05,  4.45it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 126/149 [00:28<00:05,  4.46it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 127/149 [00:29<00:04,  4.47it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 128/149 [00:29<00:04,  4.49it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 129/149 [00:29<00:04,  4.41it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 130/149 [00:29<00:04,  4.18it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 131/149 [00:30<00:04,  4.25it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 132/149 [00:30<00:03,  4.32it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 133/149 [00:30<00:03,  4.38it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 134/149 [00:30<00:03,  4.42it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 135/149 [00:30<00:03,  4.45it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 136/149 [00:31<00:02,  4.48it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 137/149 [00:31<00:02,  4.49it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 138/149 [00:31<00:02,  4.49it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 139/149 [00:31<00:02,  4.50it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 140/149 [00:32<00:02,  4.25it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 141/149 [00:32<00:01,  4.18it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 142/149 [00:32<00:01,  4.28it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 143/149 [00:32<00:01,  4.34it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 144/149 [00:32<00:01,  4.40it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 145/149 [00:33<00:00,  4.43it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 146/149 [00:33<00:00,  4.46it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 147/149 [00:33<00:00,  4.48it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 148/149 [00:33<00:00,  4.49it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:34<00:00,  4.49it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:34<00:00,  4.37it/s]

We can also save the trajectory of the posterior sample

anim = dinv.utils.plot_videos(
    trajectory[::gif_frequency],
    time_dim=0,
    titles=["Posterior sample with VE"],
    figsize=(figsize, figsize),
    return_anim=True,
)
anim


Note

Reproducibility: To ensure the reproducibility, if the parameter rng is given, the same sample will be generated when the same seed is used. One can obtain varying samples by using a different seed.

Parallel sampling: one can draw multiple samples in parallel by giving the initial shape, e.g., x_init = (B, C, H, W)

Varying the SDE#

One can also change the underlying SDE for sampling. For example, we can also use the Variance-Preserving (VP or DDPM) in deepinv.sampling.VariancePreservingDiffusion, whose forward drift and diffusion term are defined as:

\[f(x_t, t) = -\frac{1}{2} \beta(t)x_t \qquad \mbox{ and } \qquad g(t) = \beta(t) \qquad \mbox{ with } \beta(t) = \beta_{\mathrm{min}} + t \left( \beta_{\mathrm{max}} - \beta_{\mathrm{min}} \right).\]
del trajectory

sde = VariancePreservingDiffusion(alpha=0.01, device=device, dtype=dtype)
model = PosteriorDiffusion(
    data_fidelity=dps_fidelity,
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    device=device,
    dtype=dtype,
    verbose=True,
)

x_hat_vp, trajectory = model(
    y,
    physics,
    seed=111,
    get_trajectory=True,
    denoise_output=True,  # We set this to True to perform an additional denoising step at the end
)
x_hat = x
dinv.utils.plot(
    [x_hat, x_hat_vp],
    titles=[
        "posterior sample with VE",
        "posterior sample with VP",
    ],
    figsize=(figsize * 2, figsize),
)
posterior sample with VE, posterior sample with VP
  0%|          | 0/149 [00:00<?, ?it/s]
  1%|          | 1/149 [00:00<00:32,  4.51it/s]
  1%|▏         | 2/149 [00:00<00:32,  4.51it/s]
  2%|▏         | 3/149 [00:00<00:32,  4.52it/s]
  3%|β–Ž         | 4/149 [00:00<00:32,  4.52it/s]
  3%|β–Ž         | 5/149 [00:01<00:32,  4.46it/s]
  4%|▍         | 6/149 [00:01<00:34,  4.18it/s]
  5%|▍         | 7/149 [00:01<00:33,  4.20it/s]
  5%|β–Œ         | 8/149 [00:01<00:32,  4.30it/s]
  6%|β–Œ         | 9/149 [00:02<00:32,  4.36it/s]
  7%|β–‹         | 10/149 [00:02<00:31,  4.42it/s]
  7%|β–‹         | 11/149 [00:02<00:30,  4.45it/s]
  8%|β–Š         | 12/149 [00:02<00:30,  4.47it/s]
  9%|β–Š         | 13/149 [00:02<00:30,  4.49it/s]
  9%|β–‰         | 14/149 [00:03<00:29,  4.50it/s]
 10%|β–ˆ         | 15/149 [00:03<00:29,  4.51it/s]
 11%|β–ˆ         | 16/149 [00:03<00:31,  4.27it/s]
 11%|β–ˆβ–        | 17/149 [00:03<00:31,  4.14it/s]
 12%|β–ˆβ–        | 18/149 [00:04<00:30,  4.25it/s]
 13%|β–ˆβ–Ž        | 19/149 [00:04<00:29,  4.33it/s]
 13%|β–ˆβ–Ž        | 20/149 [00:04<00:29,  4.39it/s]
 14%|β–ˆβ–        | 21/149 [00:04<00:28,  4.44it/s]
 15%|β–ˆβ–        | 22/149 [00:05<00:28,  4.46it/s]
 15%|β–ˆβ–Œ        | 23/149 [00:05<00:28,  4.48it/s]
 16%|β–ˆβ–Œ        | 24/149 [00:05<00:27,  4.49it/s]
 17%|β–ˆβ–‹        | 25/149 [00:05<00:27,  4.49it/s]
 17%|β–ˆβ–‹        | 26/149 [00:05<00:28,  4.29it/s]
 18%|β–ˆβ–Š        | 27/149 [00:06<00:30,  4.04it/s]
 19%|β–ˆβ–‰        | 28/149 [00:06<00:28,  4.17it/s]
 19%|β–ˆβ–‰        | 29/149 [00:06<00:28,  4.28it/s]
 20%|β–ˆβ–ˆ        | 30/149 [00:06<00:27,  4.34it/s]
 21%|β–ˆβ–ˆ        | 31/149 [00:07<00:26,  4.40it/s]
 21%|β–ˆβ–ˆβ–       | 32/149 [00:07<00:26,  4.44it/s]
 22%|β–ˆβ–ˆβ–       | 33/149 [00:07<00:25,  4.47it/s]
 23%|β–ˆβ–ˆβ–Ž       | 34/149 [00:07<00:25,  4.49it/s]
 23%|β–ˆβ–ˆβ–Ž       | 35/149 [00:07<00:25,  4.50it/s]
 24%|β–ˆβ–ˆβ–       | 36/149 [00:08<00:25,  4.51it/s]
 25%|β–ˆβ–ˆβ–       | 37/149 [00:08<00:26,  4.28it/s]
 26%|β–ˆβ–ˆβ–Œ       | 38/149 [00:08<00:26,  4.21it/s]
 26%|β–ˆβ–ˆβ–Œ       | 39/149 [00:08<00:25,  4.31it/s]
 27%|β–ˆβ–ˆβ–‹       | 40/149 [00:09<00:24,  4.37it/s]
 28%|β–ˆβ–ˆβ–Š       | 41/149 [00:09<00:24,  4.42it/s]
 28%|β–ˆβ–ˆβ–Š       | 42/149 [00:09<00:24,  4.44it/s]
 29%|β–ˆβ–ˆβ–‰       | 43/149 [00:09<00:23,  4.47it/s]
 30%|β–ˆβ–ˆβ–‰       | 44/149 [00:10<00:23,  4.49it/s]
 30%|β–ˆβ–ˆβ–ˆ       | 45/149 [00:10<00:23,  4.50it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 46/149 [00:10<00:22,  4.51it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 47/149 [00:10<00:23,  4.37it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 48/149 [00:10<00:24,  4.14it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 49/149 [00:11<00:23,  4.25it/s]
 34%|β–ˆβ–ˆβ–ˆβ–Ž      | 50/149 [00:11<00:22,  4.33it/s]
 34%|β–ˆβ–ˆβ–ˆβ–      | 51/149 [00:11<00:22,  4.39it/s]
 35%|β–ˆβ–ˆβ–ˆβ–      | 52/149 [00:11<00:21,  4.43it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 53/149 [00:12<00:21,  4.46it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 54/149 [00:12<00:21,  4.47it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 55/149 [00:12<00:20,  4.49it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 56/149 [00:12<00:20,  4.50it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 57/149 [00:12<00:20,  4.50it/s]
 39%|β–ˆβ–ˆβ–ˆβ–‰      | 58/149 [00:13<00:21,  4.23it/s]
 40%|β–ˆβ–ˆβ–ˆβ–‰      | 59/149 [00:13<00:21,  4.21it/s]
 40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 60/149 [00:13<00:20,  4.30it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆ      | 61/149 [00:13<00:20,  4.36it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 62/149 [00:14<00:19,  4.41it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 63/149 [00:14<00:19,  4.45it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 64/149 [00:14<00:19,  4.47it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 65/149 [00:14<00:18,  4.49it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 66/149 [00:15<00:18,  4.51it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 67/149 [00:15<00:18,  4.51it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 68/149 [00:15<00:18,  4.35it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 69/149 [00:15<00:19,  4.17it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 70/149 [00:15<00:18,  4.27it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 71/149 [00:16<00:17,  4.35it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 72/149 [00:16<00:17,  4.40it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 73/149 [00:16<00:17,  4.44it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 74/149 [00:16<00:16,  4.46it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 75/149 [00:17<00:16,  4.48it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 76/149 [00:17<00:16,  4.50it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 77/149 [00:17<00:15,  4.51it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 78/149 [00:17<00:15,  4.50it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 79/149 [00:18<00:16,  4.24it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 80/149 [00:18<00:16,  4.24it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 81/149 [00:18<00:15,  4.33it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 82/149 [00:18<00:15,  4.39it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 83/149 [00:18<00:14,  4.43it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 84/149 [00:19<00:14,  4.45it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 85/149 [00:19<00:14,  4.47it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 86/149 [00:19<00:14,  4.49it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 87/149 [00:19<00:13,  4.50it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 88/149 [00:20<00:13,  4.51it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 89/149 [00:20<00:13,  4.33it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 90/149 [00:20<00:14,  4.15it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 91/149 [00:20<00:13,  4.26it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 92/149 [00:20<00:13,  4.34it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 93/149 [00:21<00:12,  4.39it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 94/149 [00:21<00:12,  4.43it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 95/149 [00:21<00:12,  4.47it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 96/149 [00:21<00:11,  4.48it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 97/149 [00:22<00:11,  4.50it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 98/149 [00:22<00:11,  4.41it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 99/149 [00:22<00:11,  4.24it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 100/149 [00:22<00:12,  3.99it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 101/149 [00:23<00:12,  3.97it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 102/149 [00:23<00:11,  3.96it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 103/149 [00:23<00:11,  3.97it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 104/149 [00:23<00:10,  4.12it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 105/149 [00:24<00:10,  4.23it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 106/149 [00:24<00:09,  4.32it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 107/149 [00:24<00:09,  4.38it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 108/149 [00:24<00:09,  4.42it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 109/149 [00:24<00:09,  4.34it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 110/149 [00:25<00:09,  4.10it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 111/149 [00:25<00:09,  4.22it/s]
 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 112/149 [00:25<00:08,  4.30it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 113/149 [00:25<00:08,  4.36it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 114/149 [00:26<00:07,  4.41it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 115/149 [00:26<00:07,  4.44it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 116/149 [00:26<00:07,  4.46it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 117/149 [00:26<00:07,  4.48it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 118/149 [00:27<00:06,  4.49it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 119/149 [00:27<00:06,  4.50it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 120/149 [00:27<00:06,  4.26it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 121/149 [00:27<00:06,  4.21it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 122/149 [00:27<00:06,  4.31it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 123/149 [00:28<00:05,  4.36it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 124/149 [00:28<00:05,  4.41it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 125/149 [00:28<00:05,  4.44it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 126/149 [00:28<00:05,  4.46it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 127/149 [00:29<00:04,  4.48it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 128/149 [00:29<00:04,  4.49it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 129/149 [00:29<00:04,  4.49it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 130/149 [00:29<00:04,  4.33it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 131/149 [00:30<00:04,  4.14it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 132/149 [00:30<00:03,  4.25it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 133/149 [00:30<00:03,  4.32it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 134/149 [00:30<00:03,  4.38it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 135/149 [00:30<00:03,  4.42it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 136/149 [00:31<00:02,  4.45it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 137/149 [00:31<00:02,  4.47it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 138/149 [00:31<00:02,  4.49it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 139/149 [00:31<00:02,  4.50it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 140/149 [00:32<00:02,  4.48it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 141/149 [00:32<00:01,  4.24it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 142/149 [00:32<00:01,  4.23it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 143/149 [00:32<00:01,  4.31it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 144/149 [00:32<00:01,  4.37it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 145/149 [00:33<00:00,  4.41it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 146/149 [00:33<00:00,  4.44it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 147/149 [00:33<00:00,  4.46it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 148/149 [00:33<00:00,  4.43it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:34<00:00,  4.36it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 149/149 [00:34<00:00,  4.37it/s]

We can also save the trajectory of the posterior sample

anim = dinv.utils.plot_videos(
    trajectory[::gif_frequency],
    time_dim=0,
    titles=["Posterior sample with VP"],
    figsize=(figsize, figsize),
    return_anim=True,
)
anim


Plug-and-play Posterior Sampling with arbitrary denoisers#

The deepinv.sampling.PosteriorDiffusion class can be used together with any (well-trained) denoisers for posterior sampling. For example, we can use the deepinv.models.DRUNet for posterior sampling. We can also change the underlying SDE, for example change the sigma_max value.

del trajectory  # clean memory
sigma_max = 10.0
rng = torch.Generator(device)
dtype = torch.float32
timesteps = torch.linspace(1, 0.001, 250)
solver = EulerSolver(timesteps=timesteps, rng=rng)
denoiser = dinv.models.DRUNet(pretrained="download").to(device)

sde = VarianceExplodingDiffusion(
    sigma_max=sigma_max, alpha=0.75, device=device, dtype=dtype
)

x = dinv.utils.load_example(
    "butterfly.png",
    img_size=256,
    resize_mode="resize",
).to(device)

mask = torch.ones_like(x)
mask[..., 70:150, 120:180] = 0
physics = dinv.physics.Inpainting(
    mask=mask,
    img_size=x.shape[1:],
    device=device,
)

y = physics(x)
model = PosteriorDiffusion(
    data_fidelity=DPSDataFidelity(denoiser=denoiser, weight=0.3),
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)

# To perform posterior sampling, we need to provide the measurements, the physics and the solver.
x_hat, trajectory = model(
    y=y,
    physics=physics,
    seed=1,
    get_trajectory=True,
    denoise_output=True,
)

# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
    [x, y, x_hat.clip(0, 1)],
    titles=["Original", "Measurement", "Posterior sample DRUNet"],
    figsize=(figsize * 3, figsize),
)
Original, Measurement, Posterior sample DRUNet
  0%|          | 0/249 [00:00<?, ?it/s]
  0%|          | 1/249 [00:00<01:55,  2.14it/s]
  1%|          | 2/249 [00:00<02:01,  2.03it/s]
  1%|          | 3/249 [00:01<02:04,  1.97it/s]
  2%|▏         | 4/249 [00:01<02:02,  2.00it/s]
  2%|▏         | 5/249 [00:02<02:00,  2.02it/s]
  2%|▏         | 6/249 [00:02<01:59,  2.03it/s]
  3%|β–Ž         | 7/249 [00:03<02:02,  1.97it/s]
  3%|β–Ž         | 8/249 [00:04<02:02,  1.97it/s]
  4%|β–Ž         | 9/249 [00:04<02:00,  1.99it/s]
  4%|▍         | 10/249 [00:05<01:59,  1.99it/s]
  4%|▍         | 11/249 [00:05<01:59,  1.99it/s]
  5%|▍         | 12/249 [00:06<02:03,  1.92it/s]
  5%|β–Œ         | 13/249 [00:06<02:01,  1.94it/s]
  6%|β–Œ         | 14/249 [00:07<02:00,  1.96it/s]
  6%|β–Œ         | 15/249 [00:07<01:58,  1.97it/s]
  6%|β–‹         | 16/249 [00:08<01:58,  1.96it/s]
  7%|β–‹         | 17/249 [00:08<02:02,  1.90it/s]
  7%|β–‹         | 18/249 [00:09<01:59,  1.93it/s]
  8%|β–Š         | 19/249 [00:09<01:58,  1.94it/s]
  8%|β–Š         | 20/249 [00:10<01:57,  1.95it/s]
  8%|β–Š         | 21/249 [00:10<01:58,  1.92it/s]
  9%|β–‰         | 22/249 [00:11<01:59,  1.90it/s]
  9%|β–‰         | 23/249 [00:11<01:57,  1.92it/s]
 10%|β–‰         | 24/249 [00:12<01:56,  1.93it/s]
 10%|β–ˆ         | 25/249 [00:12<01:55,  1.94it/s]
 10%|β–ˆ         | 26/249 [00:13<01:59,  1.87it/s]
 11%|β–ˆ         | 27/249 [00:13<01:56,  1.90it/s]
 11%|β–ˆ         | 28/249 [00:14<01:55,  1.91it/s]
 12%|β–ˆβ–        | 29/249 [00:14<01:54,  1.92it/s]
 12%|β–ˆβ–        | 30/249 [00:15<01:54,  1.91it/s]
 12%|β–ˆβ–        | 31/249 [00:15<01:56,  1.87it/s]
 13%|β–ˆβ–Ž        | 32/249 [00:16<01:54,  1.89it/s]
 13%|β–ˆβ–Ž        | 33/249 [00:17<01:53,  1.90it/s]
 14%|β–ˆβ–Ž        | 34/249 [00:17<01:52,  1.91it/s]
 14%|β–ˆβ–        | 35/249 [00:18<01:55,  1.86it/s]
 14%|β–ˆβ–        | 36/249 [00:18<01:53,  1.87it/s]
 15%|β–ˆβ–        | 37/249 [00:19<01:52,  1.88it/s]
 15%|β–ˆβ–Œ        | 38/249 [00:19<01:51,  1.89it/s]
 16%|β–ˆβ–Œ        | 39/249 [00:20<01:51,  1.89it/s]
 16%|β–ˆβ–Œ        | 40/249 [00:20<01:53,  1.84it/s]
 16%|β–ˆβ–‹        | 41/249 [00:21<01:51,  1.86it/s]
 17%|β–ˆβ–‹        | 42/249 [00:21<01:50,  1.87it/s]
 17%|β–ˆβ–‹        | 43/249 [00:22<01:49,  1.88it/s]
 18%|β–ˆβ–Š        | 44/249 [00:22<01:51,  1.84it/s]
 18%|β–ˆβ–Š        | 45/249 [00:23<01:50,  1.84it/s]
 18%|β–ˆβ–Š        | 46/249 [00:24<01:49,  1.85it/s]
 19%|β–ˆβ–‰        | 47/249 [00:24<01:47,  1.87it/s]
 19%|β–ˆβ–‰        | 48/249 [00:25<01:47,  1.88it/s]
 20%|β–ˆβ–‰        | 49/249 [00:25<01:50,  1.81it/s]
 20%|β–ˆβ–ˆ        | 50/249 [00:26<01:48,  1.84it/s]
 20%|β–ˆβ–ˆ        | 51/249 [00:26<01:46,  1.85it/s]
 21%|β–ˆβ–ˆ        | 52/249 [00:27<01:46,  1.86it/s]
 21%|β–ˆβ–ˆβ–       | 53/249 [00:27<01:48,  1.81it/s]
 22%|β–ˆβ–ˆβ–       | 54/249 [00:28<01:47,  1.82it/s]
 22%|β–ˆβ–ˆβ–       | 55/249 [00:28<01:45,  1.84it/s]
 22%|β–ˆβ–ˆβ–       | 56/249 [00:29<01:44,  1.85it/s]
 23%|β–ˆβ–ˆβ–Ž       | 57/249 [00:29<01:44,  1.84it/s]
 23%|β–ˆβ–ˆβ–Ž       | 58/249 [00:30<01:46,  1.80it/s]
 24%|β–ˆβ–ˆβ–Ž       | 59/249 [00:31<01:44,  1.82it/s]
 24%|β–ˆβ–ˆβ–       | 60/249 [00:31<01:42,  1.84it/s]
 24%|β–ˆβ–ˆβ–       | 61/249 [00:32<01:41,  1.85it/s]
 25%|β–ˆβ–ˆβ–       | 62/249 [00:32<01:44,  1.78it/s]
 25%|β–ˆβ–ˆβ–Œ       | 63/249 [00:33<01:43,  1.81it/s]
 26%|β–ˆβ–ˆβ–Œ       | 64/249 [00:33<01:41,  1.81it/s]
 26%|β–ˆβ–ˆβ–Œ       | 65/249 [00:34<01:41,  1.82it/s]
 27%|β–ˆβ–ˆβ–‹       | 66/249 [00:34<01:41,  1.80it/s]
 27%|β–ˆβ–ˆβ–‹       | 67/249 [00:35<01:42,  1.77it/s]
 27%|β–ˆβ–ˆβ–‹       | 68/249 [00:36<01:41,  1.79it/s]
 28%|β–ˆβ–ˆβ–Š       | 69/249 [00:36<01:39,  1.81it/s]
 28%|β–ˆβ–ˆβ–Š       | 70/249 [00:37<01:38,  1.82it/s]
 29%|β–ˆβ–ˆβ–Š       | 71/249 [00:37<01:41,  1.76it/s]
 29%|β–ˆβ–ˆβ–‰       | 72/249 [00:38<01:39,  1.78it/s]
 29%|β–ˆβ–ˆβ–‰       | 73/249 [00:38<01:37,  1.80it/s]
 30%|β–ˆβ–ˆβ–‰       | 74/249 [00:39<01:36,  1.81it/s]
 30%|β–ˆβ–ˆβ–ˆ       | 75/249 [00:40<01:38,  1.77it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 76/249 [00:40<01:37,  1.77it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 77/249 [00:41<01:36,  1.78it/s]
 31%|β–ˆβ–ˆβ–ˆβ–      | 78/249 [00:41<01:35,  1.79it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 79/249 [00:42<01:35,  1.78it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 80/249 [00:42<01:36,  1.75it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 81/249 [00:43<01:35,  1.77it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 82/249 [00:43<01:33,  1.78it/s]
 33%|β–ˆβ–ˆβ–ˆβ–Ž      | 83/249 [00:44<01:32,  1.79it/s]
 34%|β–ˆβ–ˆβ–ˆβ–Ž      | 84/249 [00:45<01:35,  1.73it/s]
 34%|β–ˆβ–ˆβ–ˆβ–      | 85/249 [00:45<01:33,  1.75it/s]
 35%|β–ˆβ–ˆβ–ˆβ–      | 86/249 [00:46<01:32,  1.77it/s]
 35%|β–ˆβ–ˆβ–ˆβ–      | 87/249 [00:46<01:31,  1.77it/s]
 35%|β–ˆβ–ˆβ–ˆβ–Œ      | 88/249 [00:47<01:33,  1.72it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 89/249 [00:48<01:31,  1.74it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 90/249 [00:48<01:30,  1.75it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 91/249 [00:49<01:29,  1.76it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 92/249 [00:49<01:30,  1.74it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 93/249 [00:50<01:30,  1.73it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 94/249 [00:50<01:28,  1.74it/s]
 38%|β–ˆβ–ˆβ–ˆβ–Š      | 95/249 [00:51<01:27,  1.76it/s]
 39%|β–ˆβ–ˆβ–ˆβ–Š      | 96/249 [00:51<01:27,  1.75it/s]
 39%|β–ˆβ–ˆβ–ˆβ–‰      | 97/249 [00:52<01:28,  1.72it/s]
 39%|β–ˆβ–ˆβ–ˆβ–‰      | 98/249 [00:53<01:27,  1.73it/s]
 40%|β–ˆβ–ˆβ–ˆβ–‰      | 99/249 [00:53<01:25,  1.75it/s]
 40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 100/249 [00:54<01:24,  1.76it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆ      | 101/249 [00:54<01:29,  1.65it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆ      | 102/249 [00:55<01:27,  1.68it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 103/249 [00:56<01:25,  1.71it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 104/249 [00:56<01:23,  1.73it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 105/249 [00:57<01:25,  1.68it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 106/249 [00:57<01:23,  1.71it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 107/249 [00:58<01:22,  1.73it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 108/249 [00:59<01:20,  1.74it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 109/249 [00:59<01:21,  1.72it/s]
 44%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 110/249 [01:00<01:21,  1.71it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 111/249 [01:00<01:19,  1.74it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 112/249 [01:01<01:18,  1.74it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 113/249 [01:01<01:18,  1.73it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 114/249 [01:02<01:19,  1.70it/s]
 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 115/249 [01:03<01:17,  1.72it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 116/249 [01:03<01:16,  1.73it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 117/249 [01:04<01:16,  1.73it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 118/249 [01:04<01:19,  1.65it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 119/249 [01:05<01:18,  1.65it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 120/249 [01:06<01:18,  1.64it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 121/249 [01:06<01:19,  1.61it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 122/249 [01:07<01:23,  1.53it/s]
 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 123/249 [01:08<01:21,  1.55it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 124/249 [01:08<01:18,  1.59it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 125/249 [01:09<01:15,  1.63it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 126/249 [01:09<01:16,  1.62it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 127/249 [01:10<01:13,  1.66it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 128/249 [01:11<01:11,  1.69it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 129/249 [01:11<01:09,  1.72it/s]
 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 130/249 [01:12<01:11,  1.67it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 131/249 [01:12<01:09,  1.70it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 132/249 [01:13<01:08,  1.72it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 133/249 [01:13<01:07,  1.73it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 134/249 [01:14<01:08,  1.67it/s]
 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 135/249 [01:15<01:07,  1.70it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 136/249 [01:15<01:05,  1.72it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 137/249 [01:16<01:05,  1.72it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 138/249 [01:16<01:05,  1.69it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 139/249 [01:17<01:05,  1.68it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 140/249 [01:18<01:04,  1.70it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 141/249 [01:18<01:03,  1.69it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 142/249 [01:19<01:03,  1.68it/s]
 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 143/249 [01:19<01:03,  1.66it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 144/249 [01:20<01:02,  1.68it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 145/249 [01:21<01:01,  1.69it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 146/249 [01:21<01:01,  1.68it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 147/249 [01:22<01:03,  1.61it/s]
 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 148/249 [01:22<01:01,  1.64it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 149/249 [01:23<00:59,  1.67it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 150/249 [01:24<00:58,  1.69it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 151/249 [01:24<00:59,  1.65it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 152/249 [01:25<00:58,  1.66it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 153/249 [01:25<00:57,  1.68it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 154/249 [01:26<00:56,  1.69it/s]
 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 155/249 [01:27<01:00,  1.55it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 156/249 [01:27<00:57,  1.61it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 157/249 [01:28<00:56,  1.64it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 158/249 [01:29<00:54,  1.67it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 159/249 [01:29<00:55,  1.63it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 160/249 [01:30<00:53,  1.66it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 161/249 [01:30<00:52,  1.68it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 162/249 [01:31<00:51,  1.69it/s]
 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 163/249 [01:32<00:52,  1.65it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 164/249 [01:32<00:50,  1.67it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 165/249 [01:33<00:49,  1.68it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 166/249 [01:33<00:48,  1.70it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 167/249 [01:34<00:50,  1.63it/s]
 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 168/249 [01:35<00:49,  1.63it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 169/249 [01:35<00:48,  1.65it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 170/249 [01:36<00:47,  1.67it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 171/249 [01:36<00:46,  1.66it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 172/249 [01:37<00:46,  1.65it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 173/249 [01:38<00:45,  1.68it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 174/249 [01:38<00:44,  1.69it/s]
 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 175/249 [01:39<00:44,  1.68it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 176/249 [01:39<00:45,  1.61it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 177/249 [01:40<00:43,  1.65it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 178/249 [01:41<00:42,  1.67it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 179/249 [01:41<00:42,  1.65it/s]
 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 180/249 [01:42<00:44,  1.56it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 181/249 [01:42<00:42,  1.61it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 182/249 [01:43<00:40,  1.64it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 183/249 [01:44<00:39,  1.66it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 184/249 [01:44<00:40,  1.59it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 185/249 [01:45<00:39,  1.63it/s]
 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 186/249 [01:45<00:38,  1.65it/s]
 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 187/249 [01:46<00:37,  1.67it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 188/249 [01:47<00:37,  1.62it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 189/249 [01:47<00:36,  1.65it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 190/249 [01:48<00:35,  1.67it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 191/249 [01:48<00:34,  1.69it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 192/249 [01:49<00:34,  1.64it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 193/249 [01:50<00:33,  1.66it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 194/249 [01:50<00:32,  1.68it/s]
 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 195/249 [01:51<00:32,  1.69it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š  | 196/249 [01:52<00:32,  1.65it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 197/249 [01:52<00:31,  1.66it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 198/249 [01:53<00:30,  1.67it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 199/249 [01:53<00:29,  1.68it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 200/249 [01:54<00:29,  1.66it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 201/249 [01:55<00:29,  1.65it/s]
 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 202/249 [01:55<00:28,  1.67it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 203/249 [01:56<00:27,  1.68it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 204/249 [01:56<00:27,  1.66it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 205/249 [01:57<00:26,  1.64it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 206/249 [01:58<00:25,  1.66it/s]
 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 207/249 [01:58<00:25,  1.67it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 208/249 [01:59<00:24,  1.65it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 209/249 [01:59<00:24,  1.63it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 210/249 [02:00<00:23,  1.65it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 211/249 [02:01<00:22,  1.67it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 212/249 [02:01<00:22,  1.66it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 213/249 [02:02<00:22,  1.64it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 214/249 [02:02<00:21,  1.66it/s]
 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 215/249 [02:03<00:20,  1.67it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 216/249 [02:04<00:19,  1.67it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 217/249 [02:04<00:19,  1.64it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 218/249 [02:05<00:18,  1.66it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 219/249 [02:05<00:18,  1.66it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 220/249 [02:06<00:17,  1.66it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 221/249 [02:07<00:17,  1.62it/s]
 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 222/249 [02:07<00:16,  1.64it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 223/249 [02:08<00:15,  1.65it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 224/249 [02:08<00:15,  1.66it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 225/249 [02:09<00:14,  1.62it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 226/249 [02:10<00:14,  1.64it/s]
 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 227/249 [02:10<00:13,  1.66it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 228/249 [02:11<00:12,  1.67it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 229/249 [02:11<00:12,  1.62it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 230/249 [02:12<00:11,  1.65it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 231/249 [02:13<00:10,  1.67it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 232/249 [02:13<00:10,  1.67it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 233/249 [02:14<00:09,  1.63it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 234/249 [02:14<00:09,  1.65it/s]
 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 235/249 [02:15<00:08,  1.66it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 236/249 [02:16<00:07,  1.67it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 237/249 [02:16<00:07,  1.63it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 238/249 [02:17<00:06,  1.65it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 239/249 [02:17<00:05,  1.67it/s]
 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 240/249 [02:18<00:05,  1.68it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 241/249 [02:19<00:05,  1.59it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 242/249 [02:19<00:04,  1.62it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 243/249 [02:20<00:03,  1.64it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 244/249 [02:21<00:03,  1.66it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 245/249 [02:21<00:02,  1.62it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 246/249 [02:22<00:01,  1.64it/s]
 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 247/249 [02:22<00:01,  1.66it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 248/249 [02:23<00:00,  1.67it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 249/249 [02:24<00:00,  1.59it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 249/249 [02:24<00:00,  1.73it/s]

We can also save the trajectory of the posterior sample

anim = dinv.utils.plot_videos(
    trajectory[::gif_frequency].clip(0, 1),
    time_dim=0,
    titles=["Posterior trajectory DRUNet"],
    figsize=(figsize, figsize),
    return_anim=True,
)
anim


Total running time of the script: (3 minutes 59.576 seconds)

Gallery generated by Sphinx-Gallery