.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/sampling/demo_dps.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sampling_demo_dps.py:


Implementing DPS
================

In this tutorial, we will go over the steps of the Diffusion Posterior Sampling (DPS) algorithm introduced in
`Chung et al. `_. The full algorithm is implemented in :meth:`deepinv.sampling.DPS`.

.. GENERATED FROM PYTHON SOURCE LINES 11-20

Installing dependencies
-----------------------
Let us ``import`` the relevant packages and load a sample image of size 64 x 64,
which will be used as our ground-truth image.

.. note::

  We work with an image of size 64 x 64 to reduce the computational time of this example.
  The algorithm works best with images of size 256 x 256, the resolution on which the
  diffusion model used below was trained.

.. GENERATED FROM PYTHON SOURCE LINES 20-37

.. code-block:: Python

    import numpy as np
    import torch

    import deepinv as dinv
    from deepinv.utils.plotting import plot
    from deepinv.optim.data_fidelity import L2
    from deepinv.utils.demo import load_url_image, get_image_url
    from tqdm import tqdm  # to visualize progress

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    url = get_image_url("butterfly.png")

    x_true = load_url_image(url=url, img_size=64).to(device)
    x = x_true.clone()

.. GENERATED FROM PYTHON SOURCE LINES 38-41

In this tutorial, we consider random inpainting as the inverse problem; the forward operator
is implemented in :meth:`deepinv.physics.Inpainting`. In our example, 90% of the pixels are
masked out at random (``mask=0.1`` means that each pixel is kept with probability 0.1), and
the measurements are additionally corrupted by additive white Gaussian noise (AWGN) with
standard deviation 12.75/255.

.. GENERATED FROM PYTHON SOURCE LINES 41-60

.. code-block:: Python

    sigma = 12.75 / 255.0  # noise level

    physics = dinv.physics.Inpainting(
        tensor_size=(3, x.shape[-2], x.shape[-1]),
        mask=0.1,
        pixelwise=True,
        noise_model=dinv.physics.GaussianNoise(sigma=sigma),  # the AWGN described above
        device=device,
    )

    y = physics(x_true)

    imgs = [y, x_true]
    plot(
        imgs,
        titles=["measurement", "groundtruth"],
    )

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_001.png
   :alt: measurement, groundtruth
   :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 61-68

Diffusion model loading
-----------------------

We use the same pre-trained diffusion model as for the DiffPIR algorithm, namely the one
trained on the FFHQ 256x256 dataset. This model was therefore trained on human face images,
which are very different from the image considered in this example. Nevertheless, we will
see later on that ``DPS`` generalizes sufficiently well even in this case.

.. GENERATED FROM PYTHON SOURCE LINES 68-72

.. code-block:: Python

    model = dinv.models.DiffUNet(large_model=False).to(device)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://huggingface.co/deepinv/diffunet/resolve/main/diffusion_ffhq_10m.pt?download=true" to /home/runner/.cache/torch/hub/checkpoints/diffusion_ffhq_10m.pt

.. math::

    \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t|\mathbf{y}) \approx \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) + \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\widehat{\mathbf{x}}_{t})

Remarkably, the latter term can be computed in closed form when the noise is Gaussian, since

.. math::

    \log p(\mathbf{y}|\widehat{\mathbf{x}}_{t}) = -\frac{\|\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_{t}\|_2^2}{2\sigma_y^2} + \mathrm{const},

where :math:`\sigma_y` is the standard deviation of the measurement noise and the constant
does not depend on :math:`\widehat{\mathbf{x}}_{t}`. Moreover, since :math:`\widehat{\mathbf{x}}_{t}`
is the output of the denoiser applied to :math:`\mathbf{x}_t`, the gradient w.r.t.
:math:`\mathbf{x}_t` can be computed by automatic differentiation through the denoiser.
Let us see how this can be done in PyTorch. Note that to take the gradient w.r.t. a tensor,
we first have to enable gradient tracking on it with ``tensor.requires_grad_()``.

.. note::

  The diffusion sampling steps below assume that the images are in the range [-1, 1],
  whereas standard denoisers usually take and output images in the range [0, 1].
  This is why we map the images to [0, 1] before each call to the denoiser and map
  its output back to [-1, 1].

.. GENERATED FROM PYTHON SOURCE LINES 214-246

.. code-block:: Python

    x0 = x_true * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    y = physics(x0)  # measure the rescaled image so that y and x0_t share the [-1, 1] range

    data_fidelity = L2()

    # xt ~ q(xt|x0)
    i = 200  # choose some arbitrary timestep
    t = (torch.ones(1) * i).to(device)
    at = compute_alpha(betas, t.long())
    sigma_cur = (1 - at).sqrt() / at.sqrt()
    xt = x0 + sigma_cur * torch.randn_like(x0)

    # DPS
    with torch.enable_grad():
        # Turn on gradient tracking for xt
        xt.requires_grad_()

        # normalize to [0, 1], denoise, and rescale to [-1, 1]
        x0_t = model(xt / 2 + 0.5, sigma_cur / 2) * 2 - 1
        # Residual norm ||y - A x0_t||, which DPS uses in place of the log-likelihood
        ll = data_fidelity(x0_t, y, physics).sqrt().sum()
        # Take gradient w.r.t. xt
        grad_ll = torch.autograd.grad(outputs=ll, inputs=xt)[0]

    # Visualize
    imgs = [x0, xt, x0_t, grad_ll]
    plot(
        imgs,
        titles=["groundtruth", "noisy", "posterior mean", "gradient"],
    )

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_003.png
   :alt: groundtruth, noisy, posterior mean, gradient
   :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 247-273

DPS Algorithm
-------------

Having covered all the key components of DPS, we are now ready to define the algorithm.
For every denoising timestep, the algorithm iterates the following steps, subject to two
caveats discussed below:

1. Get :math:`\widehat{\mathbf{x}}_t` using the denoiser network.
2. Compute :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\widehat{\mathbf{x}}_t)` through backpropagation.
3. Perform reverse diffusion sampling with DDPM(IM), corresponding to an update with :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)`.
4. Take a gradient step with :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\widehat{\mathbf{x}}_t)`.
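
The four steps above can be condensed into a short, self-contained sketch that runs without
the pretrained network. It is a toy illustration under stated assumptions, not the deepinv
implementation: the prior is a standard Gaussian, so the hypothetical ``toy_denoiser`` is the
exact posterior mean :math:`\sqrt{\bar{\alpha}_t}\,\mathbf{x}_t`; the forward operator is a
noiseless random pixel mask; and the linear beta schedule and the helper names
``make_schedule`` and ``toy_dps`` are illustrative choices.

.. code-block:: Python

    import torch


    def make_schedule(num_train_timesteps=1000):
        # Hypothetical linear beta schedule; returns alpha_bar_t = prod_{s<=t} (1 - beta_s)
        betas = torch.linspace(1e-4, 0.02, num_train_timesteps, dtype=torch.float64)
        return torch.cumprod(1.0 - betas, dim=0)


    def toy_denoiser(xt, at):
        # Exact posterior mean E[x0 | xt] for a standard Gaussian prior x0 ~ N(0, I),
        # with xt = sqrt(at) * x0 + sqrt(1 - at) * noise; stands in for the network
        return at.sqrt() * xt


    def toy_dps(y, mask, num_steps=50, eta=1.0, rho=1.0, seed=0):
        gen = torch.Generator().manual_seed(seed)
        alphas_cumprod = make_schedule()
        num_train_timesteps = alphas_cumprod.numel()
        seq = list(range(0, num_train_timesteps, num_train_timesteps // num_steps))
        seq_next = [-1] + seq[:-1]  # -1 stands for the clean image (alpha_bar = 1)
        one = torch.tensor(1.0, dtype=torch.float64)

        xt = torch.randn(y.shape, generator=gen, dtype=torch.float64)  # sample from x_T
        for i, j in zip(reversed(seq), reversed(seq_next)):
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j] if j >= 0 else one

            xt = xt.detach().requires_grad_()
            # 1. denoise
            x0_t = toy_denoiser(xt, at)
            # 2. gradient of the residual norm ||y - A x0_t||_2 w.r.t. xt (backprop)
            residual_norm = torch.linalg.vector_norm(y - mask * x0_t)
            norm_grad = torch.autograd.grad(residual_norm, xt)[0]

            with torch.no_grad():
                # 3. DDIM update, i.e. DDPM ancestral sampling when eta = 1
                sigma_tilde = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
                c2 = ((1 - at_next) - sigma_tilde**2).sqrt()
                eps_hat = (xt - at.sqrt() * x0_t) / (1 - at).sqrt()
                noise = torch.randn(y.shape, generator=gen, dtype=torch.float64)
                xt = at_next.sqrt() * x0_t + c2 * eps_hat + sigma_tilde * noise
                # 4. likelihood step with static step size rho
                xt = xt - rho * norm_grad
        return xt.detach()


    # Usage: toy inpainting of a 64-dimensional signal, A = diag(mask)
    torch.manual_seed(0)
    x_true = torch.randn(64, dtype=torch.float64)
    mask = (torch.rand(64) < 0.5).to(torch.float64)
    y = mask * x_true
    x_hat = toy_dps(y, mask)
    print("residual norm:", torch.linalg.vector_norm(y - mask * x_hat).item())

Setting ``rho=0.0`` turns off the likelihood step and samples from the toy prior alone, which
is a convenient way to see the effect of step 4 on the residual. In the tutorial, the
pretrained ``DiffUNet`` plays the role of ``toy_denoiser``, with the inputs rescaled between
[-1, 1] and [0, 1] around each call.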
The first caveat concerns the sampler. In the original work, DPS used DDPM ancestral sampling.
Since the `DDIM sampler `_ generalizes DDPM, in the sense that it recovers DDPM when
:math:`\eta = 1.0`, we use DDIM sampling here. One can freely choose the :math:`\eta`
parameter, but since DPS is meant to be run with a large number of neural function
evaluations (NFEs), 1000 in the setting recommended by its authors, it is advisable to keep
:math:`\eta = 1.0`.

The second caveat concerns the likelihood step. The gradient is weighted so that the actual
implementation is a static step size :math:`\rho` times the gradient of the :math:`\ell_2`
norm of the residual (not of its square); the minus sign reflects that the log-likelihood
decreases as the residual grows:

.. math::

    \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\widehat{\mathbf{x}}_{t}(\mathbf{x}_t)) \simeq -\rho \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_{t}(\mathbf{x}_t)\|_2

This is why the loop below subtracts ``norm_grad``, the gradient of the residual norm, in its update.

With these in mind, let us solve the inverse problem with DPS!

.. GENERATED FROM PYTHON SOURCE LINES 276-281

.. note::

  We only use 200 steps to reduce the computational time of this example. As suggested by
  the authors of DPS, the algorithm works best with ``num_steps = 1000``.

.. GENERATED FROM PYTHON SOURCE LINES 281-353

.. code-block:: Python

    num_steps = 200

    skip = num_train_timesteps // num_steps

    batch_size = 1
    eta = 1.0

    seq = range(0, num_train_timesteps, skip)
    seq_next = [-1] + list(seq[:-1])
    time_pairs = list(zip(reversed(seq), reversed(seq_next)))

    # measurement
    x0 = x_true * 2.0 - 1.0
    y = physics(x0.to(device))

    # initial sample from x_T
    x = torch.randn_like(x0)

    xs = [x]
    x0_preds = []

    for i, j in tqdm(time_pairs):
        t = (torch.ones(batch_size) * i).to(device)
        next_t = (torch.ones(batch_size) * j).to(device)

        at = compute_alpha(betas, t.long())
        at_next = compute_alpha(betas, next_t.long())

        xt = xs[-1].to(device)

        with torch.enable_grad():
            xt.requires_grad_()

            # 1. denoising step
            aux_x = xt / (2 * at.sqrt()) + 0.5  # renormalize in [0, 1]
            sigma_cur = (1 - at).sqrt() / at.sqrt()  # sigma_t
            x0_t = 2 * model(aux_x, sigma_cur / 2) - 1
            x0_t = torch.clip(x0_t, -1.0, 1.0)  # optional

            # 2. likelihood gradient approximation
            l2_loss = data_fidelity(x0_t, y, physics).sqrt().sum()

        norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
        norm_grad = norm_grad.detach()

        # 3. DDPM(IM) coefficients and fresh noise
        sigma_tilde = ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() * eta
        c2 = ((1 - at_next) - sigma_tilde**2).sqrt()
        epsilon = torch.randn_like(xt)

        # 3.-4. DDPM(IM) update, followed by the likelihood gradient step (- norm_grad)
        xt_next = (
            (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
            + sigma_tilde * epsilon
            + c2 * xt / (1 - at).sqrt()
            - norm_grad
        )

        # detach so that the autograd graph does not grow across iterations
        x0_preds.append(x0_t.detach().to("cpu"))
        xs.append(xt_next.detach().to("cpu"))

    recon = xs[-1]

    # plot the results
    x = recon / 2 + 0.5
    imgs = [y, x, x_true]
    plot(imgs, titles=["measurement", "model output", "groundtruth"])

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_dps_004.png
   :alt: measurement, model output, groundtruth
   :srcset: /auto_examples/sampling/images/sphx_glr_demo_dps_004.png
   :class: sphx-glr-single-img
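
The statement above that DDIM recovers DDPM ancestral sampling when :math:`\eta = 1` can be
checked numerically. The sketch below assumes a linear beta schedule (an illustrative choice,
not necessarily the one used by the pretrained model) and consecutive timesteps. It compares
the DDIM coefficients, written as in the DPS loop (``sigma_tilde`` and ``c2``, with ``at``
standing for :math:`\bar{\alpha}_t`), against the mean coefficients and variance of the DDPM
posterior :math:`q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)`.

.. code-block:: Python

    import torch

    # Hypothetical linear beta schedule, in float64 so the comparison is not limited by rounding
    betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Consecutive timesteps: t = 1, ..., T-1 and their predecessors t-1
    at = alphas_cumprod[1:]  # alpha_bar_t
    at_next = alphas_cumprod[:-1]  # alpha_bar_{t-1}
    beta_t = betas[1:]
    alpha_t = alphas[1:]

    # DDIM coefficients with eta = 1, written as in the DPS loop
    eta = 1.0
    sigma_tilde = ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() * eta
    c2 = ((1 - at_next) - sigma_tilde**2).sqrt()
    ddim_coef_x0 = at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()
    ddim_coef_xt = c2 / (1 - at).sqrt()

    # DDPM posterior q(x_{t-1} | x_t, x_0): mean coefficients and variance
    ddpm_coef_x0 = at_next.sqrt() * beta_t / (1 - at)
    ddpm_coef_xt = alpha_t.sqrt() * (1 - at_next) / (1 - at)
    ddpm_var = beta_t * (1 - at_next) / (1 - at)

    print(torch.allclose(ddim_coef_x0, ddpm_coef_x0))
    print(torch.allclose(ddim_coef_xt, ddpm_coef_xt))
    print(torch.allclose(sigma_tilde**2, ddpm_var))

Up to floating-point error, the two parameterizations coincide for every timestep, which is
the sense in which :math:`\eta = 1` recovers DDPM. The DPS loop skips timesteps
(``skip > 1``); the same algebra then applies to the subsampled schedule, with
:math:`\beta_t` replaced by :math:`1 - \bar{\alpha}_t / \bar{\alpha}_{t-\mathrm{skip}}`.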