.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/sampling/demo_dps.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial`.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sampling_demo_dps.py:

Implementing DPS
================

In this tutorial, we will go over the steps in the Diffusion Posterior Sampling
(DPS) algorithm introduced in :footcite:t:`chung2022diffusion`. The full
algorithm is implemented in :class:`deepinv.sampling.DPS`.

.. GENERATED FROM PYTHON SOURCE LINES 10-19

Installing dependencies
-----------------------

Let us ``import`` the relevant packages and load a sample image of size
64 x 64, which will serve as our ground truth image.

.. note::

    We work with an image of size 64 x 64 to reduce the computational time of
    this example. The DiffUNet we use in the algorithm works best with images
    of size 256 x 256.

.. GENERATED FROM PYTHON SOURCE LINES 19-33

.. code-block:: Python

    import torch

    import deepinv as dinv
    from deepinv.utils.plotting import plot
    from deepinv.optim.data_fidelity import L2
    from deepinv.utils.demo import load_example
    from tqdm import tqdm  # to visualize progress

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    x_true = load_example("butterfly.png", img_size=64).to(device)
    x = x_true.clone()

.. GENERATED FROM PYTHON SOURCE LINES 34-37

In this tutorial we consider random inpainting as the inverse problem, where
the forward operator is implemented in :class:`deepinv.physics.Inpainting`.
In this example, 90% of the pixels are masked out at random, and the
measurements are additionally corrupted by Additive White Gaussian Noise
(AWGN) of standard deviation 12.75/255.

.. GENERATED FROM PYTHON SOURCE LINES 37-57

.. code-block:: Python

    sigma = 12.75 / 255.0  # noise level

    physics = dinv.physics.Inpainting(
        img_size=(3, x.shape[-2], x.shape[-1]),
        mask=0.1,
        pixelwise=True,
        noise_model=dinv.physics.GaussianNoise(sigma=sigma),  # AWGN on the measurements
        device=device,
    )

    y = physics(x_true)

    plot(
        {
            "Measurement": y,
            "Ground Truth": x_true,
        }
    )

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_001.png
    :alt: Measurement, Ground Truth
    :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_001.png
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 58-65

Diffusion model loading
-----------------------

We will use a pre-trained diffusion model that was also used for the DiffPIR
algorithm, namely the one trained on the FFHQ 256x256 dataset. Note that this
means the diffusion model was trained on human face images, which are very
different from the image we consider in our example. Nevertheless, we will see
later on that ``DPS`` generalizes sufficiently well even in such a case.

.. GENERATED FROM PYTHON SOURCE LINES 65-69

.. code-block:: Python

    model = dinv.models.DiffUNet(large_model=False).to(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://huggingface.co/deepinv/diffunet/resolve/main/diffusion_ffhq_10m.pt?download=true" to /home/runner/.cache/torch/hub/checkpoints/diffusion_ffhq_10m.pt
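Diffusion schedule and likelihood gradient
------------------------------------------

DPS relies on the noise schedule with which the diffusion model was trained.
The definitions below are a minimal sketch assuming the standard linear DDPM
schedule; the values ``beta_start = 0.1 / 1000`` and ``beta_end = 20 / 1000``
and the 1000 training timesteps are assumptions matching the common DDPM
defaults. ``alphas[t]`` holds the cumulative product :math:`\bar{\alpha}_t`
used throughout the rest of the example.

.. code-block:: Python

    num_train_timesteps = 1000  # number of timesteps used during training

    # Linear beta schedule (assumed values, the standard DDPM choice)
    betas = torch.linspace(0.1 / 1000, 20 / 1000, num_train_timesteps, device=device)

    # alphas[t] stores \bar{\alpha}_t, the cumulative product of (1 - beta)
    alphas = (1 - betas).cumprod(dim=0)

Under the forward model
:math:`\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}`,
the likelihood :math:`p(\mathbf{y}|\mathbf{x}_t)` is intractable. DPS
approximates it by plugging the denoised estimate :math:`\hat{\mathbf{x}}_t`
(an estimate of the posterior mean :math:`\mathbb{E}[\mathbf{x}_0|\mathbf{x}_t]`)
into the measurement likelihood, which for Gaussian noise gives

.. math::

    \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\mathbf{x}_t) \simeq \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_{t}(\mathbf{x}_t)) = -\frac{1}{2\sigma^2} \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}_{t}(\mathbf{x}_t)\|_2^2,

a quantity we can compute by backpropagating through the denoiser. The snippet
below demonstrates this at a single timestep: we draw
:math:`\mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0)`, denoise it, evaluate
the data fidelity, and take its gradient with respect to :math:`\mathbf{x}_t`.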
.. code-block:: Python

    x0 = 2 * x_true.clone() - 1  # [0, 1] -> [-1, 1]

    data_fidelity = L2()

    # xt ~ q(xt|x0)
    t = 200  # choose some arbitrary timestep
    at = alphas[t]
    sigma_cur = (1 - at).sqrt() / at.sqrt()
    xt = x0 + sigma_cur * torch.randn_like(x0)

    # DPS
    with torch.enable_grad():
        # Turn on gradient
        xt.requires_grad_()

        # normalize to [0, 1], denoise, and rescale to [-1, 1]
        x0_t = model(xt / 2 + 0.5, sigma_cur / 2) * 2 - 1
        # Log-likelihood
        ll = data_fidelity(x0_t, y, physics).sqrt().sum()
        # Take gradient w.r.t. xt
        grad_ll = torch.autograd.grad(outputs=ll, inputs=xt)[0]

    # Visualize
    plot(
        {
            "Ground Truth": x0,
            "Noisy": xt,
            "Posterior Mean": x0_t,
            "Gradient": grad_ll,
        }
    )

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_003.png
    :alt: Ground Truth, Noisy, Posterior Mean, Gradient
    :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_003.png
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 231-257

DPS Algorithm
-------------

Having visited all the key components of DPS, we are now ready to define the
algorithm. For every denoising timestep, the algorithm iterates the following
steps:

1. Get :math:`\hat{\mathbf{x}}` using the denoiser network.
2. Compute :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t)` through backpropagation.
3. Perform reverse diffusion sampling with DDPM(IM), corresponding to an update with :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)`.
4. Take a gradient step with :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t)`.

There are two caveats here. First, the original work used DDPM ancestral
sampling. Since the DDIM sampler of :footcite:t:`song2020denoising`
generalizes DDPM, in the sense that it recovers DDPM when :math:`\eta = 1.0`,
here we consider DDIM sampling. One can freely choose the :math:`\eta`
parameter, but since we will use 1000 neural function evaluations (NFEs), it
is advisable to keep :math:`\eta = 1.0`.

Second, when taking the log-likelihood gradient step, the gradient is
reweighted so that the actual update uses a static step size :math:`\rho`
times the gradient of the :math:`\ell_2` norm of the residual:

.. math::

    \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_{t}(\mathbf{x}_t)) \simeq \rho \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}_{t}\|_2
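Putting steps 3 and 4 together, each iteration performs a single DDIM-style
update of :math:`\mathbf{x}_t` with this weighted likelihood gradient
subtracted. Writing
:math:`\tilde{\sigma}_t = \eta \sqrt{\beta_t (1 - \bar{\alpha}_{t-1}) / (1 - \bar{\alpha}_t)}`,
the update implemented in the loop below reads (a restatement of the code in
the notation above, with :math:`\rho = 1`):

.. math::

    \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{\mathbf{x}}_t
    + \sqrt{1 - \bar{\alpha}_{t-1} - \tilde{\sigma}_t^2}\,
    \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\,\hat{\mathbf{x}}_t}{\sqrt{1 - \bar{\alpha}_t}}
    + \tilde{\sigma}_t \boldsymbol{\epsilon}
    - \rho \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}_t\|_2,
    \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).

The middle term rescales the predicted noise
:math:`\hat{\boldsymbol{\epsilon}}_\theta = (\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\,\hat{\mathbf{x}}_t)/\sqrt{1 - \bar{\alpha}_t}`;
setting :math:`\eta = 1` recovers DDPM ancestral sampling, while
:math:`\eta = 0` gives the deterministic DDIM sampler.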
With these in mind, let us solve the inverse problem with DPS!

.. GENERATED FROM PYTHON SOURCE LINES 260-265

.. note::

    We only use 200 steps to reduce the computational time of this example.
    As suggested by the authors of DPS, the algorithm works best with
    ``num_steps = 1000``.

.. GENERATED FROM PYTHON SOURCE LINES 265-338

.. code-block:: Python

    num_steps = 200
    skip = num_train_timesteps // num_steps
    batch_size = 1
    eta = 1.0  # DDPM scheme; use eta < 1 for DDIM

    # measurement (in the [-1, 1] range used by the diffusion model)
    x0 = x_true * 2.0 - 1.0
    y = physics(x0.to(device))

    # initial sample from x_T
    x = torch.randn_like(x0)

    xs = [x]
    x0_preds = []

    for t in tqdm(reversed(range(0, num_train_timesteps, skip))):
        at = alphas[t]
        at_next = alphas[t - skip] if t - skip >= 0 else torch.tensor(1.0)

        # we cannot use bt = betas[t] when skip > 1; use the effective beta of
        # the skipped interval instead (equal to betas[t] when skip == 1)
        bt = 1 - at / at_next

        xt = xs[-1].to(device)

        with torch.enable_grad():
            xt.requires_grad_()

            # 1. denoising step
            aux_x = xt / (2 * at.sqrt()) + 0.5  # renormalize in [0, 1]
            sigma_cur = (1 - at).sqrt() / at.sqrt()  # sigma_t

            x0_t = 2 * model(aux_x, sigma_cur / 2) - 1
            x0_t = torch.clip(x0_t, -1.0, 1.0)  # optional

            # 2. likelihood gradient approximation
            l2_loss = data_fidelity(x0_t, y, physics).sqrt().sum()

        norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
        norm_grad = norm_grad.detach()

        sigma_tilde = (bt * (1 - at_next) / (1 - at)).sqrt() * eta
        c2 = ((1 - at_next) - sigma_tilde**2).sqrt()

        # 3. noise step
        epsilon = torch.randn_like(xt)

        # 4. DDPM(IM) step
        xt_next = (
            (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
            + sigma_tilde * epsilon
            + c2 * xt / (1 - at).sqrt()
            - norm_grad
        )

        # detach so the graph of one iteration is not kept alive in the next
        x0_preds.append(x0_t.detach().to("cpu"))
        xs.append(xt_next.detach().to("cpu"))

    recon = xs[-1]

    # plot the results
    x = recon / 2 + 0.5

    plot(
        {
            "Measurement": y,
            "Model Output": x,
            "Ground Truth": x_true,
        }
    )

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_004.png
    :alt: Measurement, Model Output, Ground Truth
    :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_004.png
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    200it [01:28, 2.26it/s]
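To put a number on the reconstruction quality, we can compute the PSNR of the
model output against the ground truth. The snippet below is a minimal check in
plain PyTorch (assuming, as above, images in the :math:`[0, 1]` range, so the
peak value is 1):

.. code-block:: Python

    # mean squared error between reconstruction and ground truth
    mse = torch.mean((x.to(device) - x_true) ** 2)
    # peak signal-to-noise ratio in dB, with peak value 1.0
    psnr = 10 * torch.log10(1.0 / mse)
    print(f"PSNR: {psnr.item():.2f} dB")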
.. GENERATED FROM PYTHON SOURCE LINES 339-349

Using DPS in your inverse problem
---------------------------------

You can readily use this algorithm via the :class:`deepinv.sampling.DPS` class.

::

    y = physics(x)
    model = dinv.sampling.DPS(dinv.models.DiffUNet(), data_fidelity=dinv.optim.data_fidelity.L2())
    xhat = model(y, physics)

.. GENERATED FROM PYTHON SOURCE LINES 351-354

:References:

.. footbibliography::

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (1 minutes 31.922 seconds)

.. _sphx_glr_download_auto_examples_sampling_demo_dps.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: demo_dps.ipynb <demo_dps.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: demo_dps.py <demo_dps.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: demo_dps.zip <demo_dps.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_