.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/sampling/demo_ddrm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sampling_demo_ddrm.py:


Image reconstruction with a diffusion model
====================================================================================================

This example shows how to use the DDRM diffusion algorithm to reconstruct images from incomplete and noisy measurements, and how to compute the uncertainty of such a reconstruction.
The paper can be found at https://arxiv.org/pdf/2209.11888.pdf.

The DDRM method requires that:

* The operator has a singular value decomposition (i.e., the operator is a :class:`deepinv.physics.DecomposablePhysics`).
* The noise is Gaussian with known standard deviation (i.e., the noise model is :class:`deepinv.physics.GaussianNoise`).

.. GENERATED FROM PYTHON SOURCE LINES 16-23

.. code-block:: Python


    import deepinv as dinv
    from deepinv.utils.plotting import plot
    import torch
    import numpy as np
    from deepinv.utils.demo import load_url_image

.. GENERATED FROM PYTHON SOURCE LINES 24-28

Load example image from the internet
--------------------------------------------------------------

This example uses an image of Lionel Messi from Wikipedia.

.. GENERATED FROM PYTHON SOURCE LINES 28-38

.. code-block:: Python


    # use the GPU with the most free memory if CUDA is available, otherwise the CPU
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    url = (
        "https://upload.wikimedia.org/wikipedia/commons/b/b4/"
        "Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg"
    )
    # load a small version of the image to keep the example fast
    x = load_url_image(url=url, img_size=32).to(device)
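
DDRM relies on the first requirement listed above, a singular value decomposition :math:`A = U \operatorname{diag}(s) V^{\top}`, to move the problem into the spectral domain. For :math:`y = Ax + \sigma n`, each coefficient :math:`\bar{y}_i = (U^{\top} y)_i / s_i` with :math:`s_i > 0` is a direct, noisy observation of :math:`(V^{\top} x)_i` with noise standard deviation :math:`\sigma / s_i`. The NumPy snippet below is a minimal sketch of that change of variables for a small dense matrix, not deepinv's implementation; the helper ``spectral_measurement`` is hypothetical, whereas in deepinv the decomposition is supplied by :class:`deepinv.physics.DecomposablePhysics`.

.. code-block:: python

    import numpy as np


    def spectral_measurement(A, y, tol=1e-10):
        """Hypothetical helper: express y = A x (+ noise) in the spectral domain of A.

        Returns (s, ybar, Vt, observed), where A = U @ diag(s) @ Vt and
        ybar[i] = (U^T y)[i] / s[i] wherever s[i] > tol (ybar[i] = 0 elsewhere).
        """
        U, s, Vt = np.linalg.svd(A, full_matrices=False)
        observed = s > tol
        ybar = np.zeros_like(s)
        ybar[observed] = (U.T @ y)[observed] / s[observed]
        return s, ybar, Vt, observed


    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 6))  # small generic operator: 4 measurements, 6 unknowns
    x = rng.standard_normal(6)
    y = A @ x  # noiseless here, so the identity below holds up to rounding

    s, ybar, Vt, observed = spectral_measurement(A, y)
    print(np.allclose(ybar[observed], (Vt @ x)[observed]))  # True

    # With noise sigma * n, each observed spectral coefficient would carry
    # Gaussian noise of standard deviation sigma / s_i.
    sigma = 0.1
    spectral_noise_std = sigma / s[observed]

Coordinates with small singular values are observed with amplified noise, which is why DDRM needs both the decomposition and the exact noise level.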

.. GENERATED FROM PYTHON SOURCE LINES 39-43

Define forward operator and noise model
--------------------------------------------------------------

We use image inpainting as the forward operator, with a random mask (``mask=0.5``) that observes roughly half of the pixels, and Gaussian noise with standard deviation ``sigma`` as the noise model.

.. GENERATED FROM PYTHON SOURCE LINES 43-53

.. code-block:: Python


    sigma = 0.1  # noise level

    physics = dinv.physics.Inpainting(
        mask=0.5,
        tensor_size=x.shape[1:],
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=sigma),
    )

.. GENERATED FROM PYTHON SOURCE LINES 54-59

Define the MMSE denoiser
--------------------------------------------------------------

The diffusion method requires an MMSE denoiser that can be evaluated at various noise levels.
Here we use a pretrained DRUNet denoiser from the :ref:`denoisers ` module.

.. GENERATED FROM PYTHON SOURCE LINES 59-62

.. code-block:: Python


    denoiser = dinv.models.DRUNet(pretrained="download").to(device)

.. GENERATED FROM PYTHON SOURCE LINES 63-70

Create the Monte Carlo sampler
--------------------------------------------------------------

We can now reconstruct a noisy measurement using the diffusion method.
We use the DDRM method from :class:`deepinv.sampling.DDRM`, which works with inverse problems whose forward operator has a closed-form singular value decomposition.
The diffusion method requires a schedule of noise levels ``sigmas`` at which the denoiser is evaluated.
Here the schedule decreases linearly from 1 to 0, with 100 levels on a GPU and only 10 on the CPU to keep the runtime short.

.. GENERATED FROM PYTHON SOURCE LINES 70-75

.. code-block:: Python


    # noise-level schedule: fewer levels on the CPU to keep the runtime short
    sigmas = np.linspace(1, 0, 100) if torch.cuda.is_available() else np.linspace(1, 0, 10)

    diff = dinv.sampling.DDRM(denoiser=denoiser, etab=1.0, sigmas=sigmas, verbose=True)

.. GENERATED FROM PYTHON SOURCE LINES 76-79

Generate the measurement
---------------------------------------------------------------------------------

We apply the forward model to generate the noisy measurement.

.. GENERATED FROM PYTHON SOURCE LINES 79-82

.. code-block:: Python


    y = physics(x)
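
For inpainting, both DDRM requirements are easy to check: the operator multiplies the image entry-wise by a binary mask, so one valid SVD takes :math:`U = V = I` with singular values equal to the mask entries, and the noise is additive Gaussian. The NumPy sketch below writes out the measurement :math:`y = m \odot x + \sigma n` on a random stand-in image. It assumes that ``mask=0.5`` means each entry is observed independently with probability 0.5 and that noise is added to every entry; it mirrors the structure DDRM relies on rather than the actual code of :class:`deepinv.physics.Inpainting`.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    sigma = 0.1  # same noise level as in the example

    # Hypothetical stand-in for the image x: 3 channels, 32 x 32 pixels, values in [0, 1].
    x = rng.random((3, 32, 32))

    # Assumed mask model: every entry is observed independently with probability 0.5.
    mask = (rng.random(x.shape) < 0.5).astype(x.dtype)


    def A(v):
        """Inpainting operator: keep observed entries, zero out the others."""
        return mask * v


    def A_adjoint(v):
        """A is diagonal with 0/1 entries, so it is self-adjoint."""
        return mask * v


    # SVD of A: U = V = identity and the singular values are the mask entries,
    # so observed entries have s_i = 1 and keep noise std sigma / 1 = sigma.
    singular_values = mask
    observed = singular_values > 0

    y = A(x) + sigma * rng.standard_normal(x.shape)  # noisy measurement
    x_lin = A_adjoint(y)  # zero-filled linear reconstruction

    print(f"fraction of observed entries: {observed.mean():.2f}")

Because the singular values are all 0 or 1, DDRM only has to denoise the observed entries and generate the missing ones, with no noise amplification.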

.. GENERATED FROM PYTHON SOURCE LINES 83-87

Run the diffusion algorithm and plot results
---------------------------------------------------------------------------------

The diffusion algorithm returns a single sample from the posterior distribution.
We compare this sample with a simple linear reconstruction, obtained by applying the adjoint operator ``physics.A_adjoint`` to the measurement.

.. GENERATED FROM PYTHON SOURCE LINES 87-102

.. code-block:: Python


    # draw one posterior sample with DDRM
    xhat = diff(y, physics)

    # compute linear inverse (adjoint of the forward operator applied to y)
    x_lin = physics.A_adjoint(y)

    # compute PSNR
    print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
    print(f"Diffusion PSNR: {dinv.metric.PSNR()(x, xhat).item():.2f} dB")

    # plot results
    error = (xhat - x).abs().sum(dim=1).unsqueeze(1)  # per-pixel absolute error, summed over color channels (not plotted)
    imgs = [x_lin, x, xhat]
    plot(imgs, titles=["measurement", "ground truth", "DDRM reconstruction"])



.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_ddrm_001.png
   :alt: measurement, ground truth, DDRM reconstruction
   :srcset: /auto_examples/sampling/images/sphx_glr_demo_ddrm_001.png
   :class: sphx-glr-single-img
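
The PSNR values printed above summarize how close each reconstruction is to the ground truth. As a reminder of what the metric computes, here is a minimal NumPy sketch of the standard definition :math:`\mathrm{PSNR} = 10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})`, assuming pixel values in :math:`[0, 1]` so that :math:`\mathrm{MAX} = 1`. The helper ``psnr`` is hypothetical; :class:`deepinv.metric.PSNR` may differ in details such as per-image averaging over a batch.

.. code-block:: python

    import numpy as np


    def psnr(x_hat, x, max_pixel=1.0):
        """Hypothetical helper: peak signal-to-noise ratio in dB."""
        mse = np.mean((np.asarray(x_hat) - np.asarray(x)) ** 2)
        if mse == 0:
            return float("inf")  # identical images
        return float(10.0 * np.log10(max_pixel**2 / mse))


    # Worked example: an error of 0.1 on every pixel gives MSE = 0.01,
    # hence PSNR = 10 * log10(1 / 0.01) = 20 dB.
    x = np.zeros((3, 32, 32))
    x_hat = x + 0.1
    print(f"{psnr(x_hat, x):.2f} dB")  # 20.00 dB

Since the scale is logarithmic, a gain of about 3 dB corresponds to roughly halving the mean squared error.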