.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/sampling/demo_ddrm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sampling_demo_ddrm.py:


Image reconstruction with a diffusion model
====================================================================================================

This code shows you how to use the DDRM diffusion algorithm :footcite:t:`kawar2022denoising` to reconstruct images and also compute the uncertainty of a reconstruction from incomplete and noisy measurements.

The DDRM method requires that:

* The operator has a singular value decomposition (i.e., the operator is a :class:`deepinv.physics.DecomposablePhysics`).
* The noise is Gaussian with known standard deviation (i.e., the noise model is :class:`deepinv.physics.GaussianNoise`).

.. GENERATED FROM PYTHON SOURCE LINES 15-21

.. code-block:: Python

    import deepinv as dinv
    from deepinv.utils.plotting import plot
    import torch
    import numpy as np
    from deepinv.utils.demo import load_example

.. GENERATED FROM PYTHON SOURCE LINES 22-26

Load example image from the internet
--------------------------------------------------------------

This example uses an image of Messi.

.. GENERATED FROM PYTHON SOURCE LINES 26-32

.. code-block:: Python

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    x = load_example("messi.jpg", img_size=32).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 33-37

Define forward operator and noise model
--------------------------------------------------------------

We use image inpainting as the forward operator and Gaussian noise as the noise model.

.. GENERATED FROM PYTHON SOURCE LINES 37-47
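
Inpainting satisfies the first DDRM requirement in a particularly simple way: the operator just multiplies the image by a binary mask, so its SVD is trivial. The NumPy sketch below, which does not use deepinv, makes this concrete on a tiny image. It assumes that ``mask=0.5`` means each pixel is kept independently with probability 0.5, and it adds noise to every entry of the measurement. The names ``A`` and ``A_adjoint`` are hypothetical stand-ins that mirror the roles of ``physics.A`` and ``physics.A_adjoint``.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)

    # Hypothetical stand-in for an inpainting operator on a 4x4 grayscale image:
    # each pixel is kept independently with probability 0.5 (assumed meaning of mask=0.5).
    mask = (rng.random((4, 4)) < 0.5).astype(float)


    def A(x):
        # Forward operator: keep observed pixels, zero out missing ones.
        return mask * x


    def A_adjoint(y):
        # A is a real diagonal operator, so its adjoint is the same masking.
        return mask * y


    # Written as a 16x16 matrix, A is diagonal with entries in {0, 1}, so its SVD is
    # trivial: the singular vectors are pixel indicators and the singular values are
    # the mask entries.
    singular_values = np.linalg.svd(np.diag(mask.ravel()), compute_uv=False)

    # Noisy measurement y = A x + sigma * n with Gaussian noise of known std sigma
    # (noise is added to every entry here, an assumption of this sketch).
    sigma = 0.1
    x = rng.random((4, 4))
    y = A(x) + sigma * rng.standard_normal(x.shape)

Because missing pixels have singular value 0, the measurement carries no information about them. A diffusion method such as DDRM fills them in using the denoiser (the prior), while observed pixels are pulled toward ``y``.
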

.. code-block:: Python

    sigma = 0.1  # noise level
    physics = dinv.physics.Inpainting(
        mask=0.5,
        img_size=x.shape[1:],
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=sigma),
    )

.. GENERATED FROM PYTHON SOURCE LINES 48-53

Define the MMSE denoiser
--------------------------------------------------------------

The diffusion method requires an MMSE denoiser that can be evaluated at various noise levels.
Here we use a pretrained DRUNet denoiser from the :ref:`denoisers ` module.

.. GENERATED FROM PYTHON SOURCE LINES 53-56

.. code-block:: Python

    denoiser = dinv.models.DRUNet(pretrained="download").to(device)

.. GENERATED FROM PYTHON SOURCE LINES 57-64

Create the Monte Carlo sampler
--------------------------------------------------------------

We can now reconstruct a noisy measurement using the diffusion method.
We use the DDRM method from :class:`deepinv.sampling.DDRM`, which works with inverse problems whose forward operator has a closed-form singular value decomposition.
The diffusion method requires a schedule of noise levels ``sigmas`` at which the denoiser is evaluated.
Here ``sigmas`` decreases linearly from 1 to 0, with 100 levels on a GPU and only 10 on a CPU to keep the running time short.

.. GENERATED FROM PYTHON SOURCE LINES 64-69

.. code-block:: Python

    sigmas = np.linspace(1, 0, 100) if torch.cuda.is_available() else np.linspace(1, 0, 10)

    diff = dinv.sampling.DDRM(denoiser=denoiser, etab=1.0, sigmas=sigmas, verbose=True)

.. GENERATED FROM PYTHON SOURCE LINES 70-73

Generate the measurement
---------------------------------------------------------------------------------

We apply the forward model to generate the noisy measurement.

.. GENERATED FROM PYTHON SOURCE LINES 73-76

.. code-block:: Python

    y = physics(x)

.. GENERATED FROM PYTHON SOURCE LINES 77-81

Run the diffusion algorithm and plot results
---------------------------------------------------------------------------------

The diffusion algorithm returns a sample from the posterior distribution.
We compare this posterior sample with a simple linear reconstruction obtained by applying the adjoint of the forward operator to the measurement.

.. GENERATED FROM PYTHON SOURCE LINES 81-96
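
To give an idea of how a single posterior sample is produced, the NumPy sketch below applies the DDRM update rules, in their variance-exploding form as stated in the DDRM paper, to the special case of inpainting. There the singular vectors are the pixel basis and the singular values are the mask entries, so the spectral-domain updates act pixel by pixel. This is a simplified illustration, not the :class:`deepinv.sampling.DDRM` implementation, which handles general SVDs. The function names are hypothetical, ``eta`` and ``etab`` play the roles of the paper's :math:`\eta` and :math:`\eta_b` with illustrative defaults, and DRUNet is replaced by a toy denoiser: the exact MMSE denoiser for a zero-mean Gaussian prior.

.. code-block:: Python

    import numpy as np


    def gaussian_mmse_denoiser(z, sigma, prior_var=1.0):
        # Toy stand-in for DRUNet: the exact MMSE denoiser for a zero-mean Gaussian
        # prior x ~ N(0, prior_var * I) observed as z = x + sigma * n.
        return prior_var / (prior_var + sigma**2) * z


    def ddrm_inpainting(y, mask, sigma_y, sigmas, denoiser, eta=0.85, etab=1.0, rng=None):
        # Hypothetical, simplified DDRM sampler for A = diag(mask), mask in {0, 1}.
        # Assumes sigma_y > 0 and sigmas decreasing from sigmas[0] >= sigma_y down to 0.
        rng = np.random.default_rng() if rng is None else rng
        observed = mask > 0
        # Initialization at the largest noise level: observed pixels start around y,
        # missing pixels start as pure noise.
        x = np.where(
            observed,
            y + np.sqrt(sigmas[0] ** 2 - sigma_y**2) * rng.standard_normal(y.shape),
            sigmas[0] * rng.standard_normal(y.shape),
        )
        for sigma_prev, sigma_t in zip(sigmas[:-1], sigmas[1:]):
            x0 = denoiser(x, sigma_prev)  # estimate of the clean image at this step
            noise = rng.standard_normal(y.shape)
            # Missing pixels (singular value 0): DDIM-like step guided only by the prior.
            x_missing = (
                x0 + np.sqrt(1 - eta**2) * sigma_t * (x - x0) / sigma_prev + eta * sigma_t * noise
            )
            if sigma_t < sigma_y:
                # Observed pixels once sigma_t is below the measurement noise level.
                x_obs = x0 + np.sqrt(1 - eta**2) * sigma_t * (y - x0) / sigma_y + eta * sigma_t * noise
            else:
                # Observed pixels while sigma_t >= sigma_y: blend estimate and measurement.
                x_obs = (1 - etab) * x0 + etab * y + np.sqrt(sigma_t**2 - (etab * sigma_y) ** 2) * noise
            x = np.where(observed, x_obs, x_missing)
        return x


    # Usage on a toy 1D "image" with roughly half of the pixels observed.
    rng = np.random.default_rng(0)
    mask = (rng.random(64) < 0.5).astype(float)
    y = mask * (0.3 + 0.1 * rng.standard_normal(64))
    sample = ddrm_inpainting(
        y, mask, sigma_y=0.1, sigmas=np.linspace(1, 0, 10), denoiser=gaussian_mmse_denoiser, rng=rng
    )

Each call draws fresh noise, so repeated calls return different posterior samples. Averaging many such samples is one way to estimate the posterior mean and a per-pixel uncertainty.
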

.. code-block:: Python

    xhat = diff(y, physics)

    # linear reconstruction via the adjoint (for inpainting: missing pixels left at zero)
    x_lin = physics.A_adjoint(y)

    # compute PSNR
    print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
    print(f"Diffusion PSNR: {dinv.metric.PSNR()(x, xhat).item():.2f} dB")

    # plot results
    error = (xhat - x).abs().sum(dim=1).unsqueeze(1)  # per-pixel absolute error summed over channels (not plotted)
    imgs = [x_lin, x, xhat]
    plot(imgs, titles=["measurement", "ground truth", "DDRM reconstruction"])

.. image-sg:: /auto_examples/sampling/images/sphx_glr_demo_ddrm_001.png
   :alt: measurement, ground truth, DDRM reconstruction
   :srcset: /auto_examples/sampling/images/sphx_glr_demo_ddrm_001.png
   :class: sphx-glr-single-img
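
The printed PSNR values compare each reconstruction with the ground truth. As a rough sketch, assuming images scaled to [0, 1] so that the peak value is 1, PSNR is :math:`10 \log_{10}(1 / \mathrm{MSE})`. The hypothetical helper below restates that definition in NumPy and shows why the adjoint is a weak baseline for inpainting: leaving half of the pixels at zero yields a large mean squared error.

.. code-block:: Python

    import numpy as np


    def psnr(x_hat, x, max_pixel=1.0):
        # Peak signal-to-noise ratio in dB: 10 * log10(max_pixel^2 / MSE).
        mse = np.mean((x_hat - x) ** 2)
        return 10 * np.log10(max_pixel**2 / mse)


    rng = np.random.default_rng(0)
    x = rng.random((3, 32, 32))  # stand-in ground truth with values in [0, 1]
    mask = (rng.random((32, 32)) < 0.5).astype(float)  # same mask for all channels

    psnr_zero_filled = psnr(mask * x, x)  # adjoint of noiseless inpainting
    psnr_offset = psnr(x + 0.1, x)  # uniform error of 0.1 gives MSE = 0.01, i.e. 20 dB
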