.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/optimization/demo_poisson_mlem.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_optimization_demo_poisson_mlem.py:

Poisson Inverse Problems with Maximum-Likelihood Expectation-Maximization (MLEM)
================================================================================

This example demonstrates how to solve Poisson inverse problems using the
**Maximum-Likelihood Expectation-Maximization (MLEM)** algorithm
:footcite:t:`sheppMaximumLikelihoodReconstruction1982`, also known as the
Richardson-Lucy algorithm in the deconvolution setting
:footcite:t:`richardsonBayesianBasedIterativeMethod1972,lucyIterativeTechniqueRectification1974`.

The Poisson observation model is:

.. math::

    y \sim \mathcal{P}\!\left(\frac{Ax}{\gamma}\right)

where :math:`A` is a linear forward operator, :math:`x \geq 0` is the image to
recover, :math:`\gamma > 0` is the gain parameter, and :math:`\mathcal{P}`
denotes the Poisson distribution.

The MLEM algorithm solves the associated maximum-likelihood problem:

.. math::

    \underset{x \geq 0}{\operatorname{min}} \,\, \sum_i \left((Ax)_i - y_i \log((Ax)_i)\right)

using the following iterative update rule:

.. math::

    x^{k+1} = \frac{x^k}{A^\top \mathbf{1}} \odot A^\top\!\left(\frac{y}{Ax^k}\right)

where :math:`\odot` denotes element-wise multiplication and the division is
also element-wise.

The MLEM algorithm is widely used in emission tomography, such as Positron
Emission Tomography (PET) and Single Photon Emission Computed Tomography
(SPECT), where the Poisson noise model is a natural fit.

We show three scenarios of increasing complexity:

1. **Deblurring** with MLEM (no prior)
2. **Deblurring** with MLEM and a Total-Variation (TV) prior
3. **2D Computed Tomography (CT)** with MLEM and a TV prior

.. code-block:: Python

    import torch
    import deepinv as dinv
    from pathlib import Path
    from torchvision import transforms

    from deepinv.utils.demo import load_dataset, load_example
    from deepinv.utils.plotting import plot, plot_curves

Setup
-----------------------------------------------

Set paths, device and random seed for reproducibility.

.. code-block:: Python

    BASE_DIR = Path(".")
    RESULTS_DIR = BASE_DIR / "results"

    torch.manual_seed(0)
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

We use a single image from the Set3C dataset.

.. code-block:: Python

    img_size = 128 if torch.cuda.is_available() else 64

    val_transform = transforms.Compose(
        [transforms.CenterCrop(img_size), transforms.ToTensor()]
    )
    dataset = load_dataset("set3c", transform=val_transform)
    x = dataset[0].unsqueeze(0).to(device)  # ground-truth image

Deblurring with MLEM without a prior
-----------------------------------------------

We create a Gaussian blur operator with Poisson noise. The MLEM/Richardson-Lucy
algorithm is a standard approach for Poisson deconvolution without any prior.

.. code-block:: Python

    # Define the blur kernel
    n_channels = 3
    filter_torch = dinv.physics.blur.gaussian_blur(sigma=(2, 2))

    gain = 1 / 100
    physics_blur = dinv.physics.BlurFFT(
        img_size=(n_channels, img_size, img_size),
        filter=filter_torch,
        device=device,
        noise_model=dinv.physics.PoissonNoise(
            gain=gain, normalize=True, clip_positive=True
        ),
    )

    # Generate noisy blurred observation
    y_blur = physics_blur(x)
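To build intuition for the role of the gain, here is a small self-contained sketch of the Poisson measurement model on a toy matrix operator (all names and sizes here are hypothetical, independent of the variables above). It assumes the `normalize=True` convention rescales the counts by the gain so that the measurement has mean :math:`Ax`; a smaller gain means more photons per pixel and hence weaker noise.

```python
import torch

torch.manual_seed(0)

# Toy nonnegative operator and image (hypothetical names and sizes).
A = torch.rand(32, 16)
x_toy = torch.rand(16)

# Normalized Poisson model: y = gain * Poisson(Ax / gain), so E[y] = Ax.
gain = 1 / 100  # many photons per pixel -> weak noise
y_clean = gain * torch.poisson(A @ x_toy / gain)

gain_noisy = 1.0  # few photons per pixel -> strong noise
y_noisy = gain_noisy * torch.poisson(A @ x_toy / gain_noisy)
```

With the seed fixed, `y_noisy` deviates from :math:`Ax` far more than `y_clean` does, which is the regime where a prior becomes important.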
The :class:`deepinv.optim.MLEM` class wraps the MLEM iterations. Without a
prior, and in the case of deconvolution, this is equivalent to the classic
Richardson-Lucy algorithm. Note that without a prior, the algorithm creates
artifacts when noise is present in the observation.

.. code-block:: Python

    data_fidelity = dinv.optim.PoissonLikelihood(gain=gain)

    model_no_prior = dinv.optim.MLEM(
        data_fidelity=data_fidelity,
        prior=None,
        max_iter=20,
        early_stop=True,
        thres_conv=1e-6,
        crit_conv="residual",
        verbose=True,
    )

    x_mlem, metrics_mlem = model_no_prior(
        y_blur, physics_blur, x_gt=x, compute_metrics=True
    )

Visualize results and PSNR values along with convergence curves.

.. code-block:: Python

    psnr_input = dinv.metric.PSNR()(x, y_blur)
    psnr_mlem = dinv.metric.PSNR()(x, x_mlem)

    plot(
        {
            "Ground Truth": x,
            "Measurement": y_blur,
            "MLEM": x_mlem,
        },
        subtitles=[
            "Reference",
            f"PSNR: {psnr_input.item():.2f} dB",
            f"PSNR: {psnr_mlem.item():.2f} dB",
        ],
        figsize=(9, 3),
    )

    plot_curves(metrics_mlem)

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_001.png
         :alt: Ground Truth, Measurement, MLEM
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_002.png
         :alt: $\text{PSNR}(x_k)$, $F(x_k)$, Residual $\frac{||x_{k+1} - x_k||}{||x_k||}$
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_002.png
         :class: sphx-glr-multi-img

Deblurring with MLEM + TV prior
----------------------------------------

As we saw, MLEM tends to amplify noise when no prior information is used.
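Before introducing a prior, it is worth seeing how little the bare multiplicative update amounts to. The following self-contained sketch runs it on a toy matrix problem with noiseless data (all names and sizes hypothetical, independent of the variables above); each iteration provably does not increase the negative Poisson log-likelihood.

```python
import torch

torch.manual_seed(0)

# Toy problem: nonnegative operator A and positive ground truth.
A = torch.rand(64, 32, dtype=torch.double)
x_true = torch.rand(32, dtype=torch.double) + 0.1
y = A @ x_true  # noiseless Poisson means, for illustration


def neg_poisson_loglik(x):
    # F(x) = sum_i (Ax)_i - y_i * log((Ax)_i)
    Ax = A @ x
    return (Ax - y * torch.log(Ax)).sum()


# MLEM: x <- x / (A^T 1) * A^T (y / (A x)), from a positive initial guess.
x = torch.ones(32, dtype=torch.double)
sensitivity = A.T @ torch.ones(64, dtype=torch.double)  # A^T 1
losses = [neg_poisson_loglik(x).item()]
for _ in range(200):
    x = x / sensitivity * (A.T @ (y / (A @ x)))
    losses.append(neg_poisson_loglik(x).item())
```

Because every factor in the update is nonnegative, the iterates stay nonnegative without any projection, which is one reason MLEM is popular for emission tomography.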
Adding a Total-Variation (TV) prior mitigates this issue by favoring
piecewise-constant solutions. There are several ways of modifying MLEM for
regularized objectives; here we use the most straightforward approach, called
One-Step-Late (OSL) :footcite:`greenUseEmAlgorithm1990`, which simply adds the
gradient of the prior to the denominator of the MLEM update:

.. math::

    x^{k+1} = \frac{x^k}{A^\top \mathbf{1} + \lambda \nabla \regname(x^k)} \odot A^\top\!\left(\frac{y}{Ax^k}\right)

For non-smooth regularizations, the penalized MLEM update becomes:

.. math::

    x^{k+1} = \frac{x^k}{A^\top \mathbf{1} + \lambda g^k} \odot A^\top\!\left(\frac{y}{Ax^k}\right), \quad g^k \in \partial \regname(x^k)

where :math:`\partial \regname(x^k)` is the subdifferential of the
regularization at :math:`x^k`. Any prior implementing the
:class:`deepinv.optim.prior.Prior` interface can be used in the
:class:`deepinv.optim.MLEM` class, and the proximal step is automatically
computed when needed.

.. code-block:: Python

    prior_tv = dinv.optim.prior.TVPrior(n_it_max=50)

    model_tv = dinv.optim.MLEM(
        data_fidelity=data_fidelity,
        prior=prior_tv,
        lambda_reg=0.02,
        max_iter=100,
        early_stop=True,
        thres_conv=1e-6,
        crit_conv="residual",
        verbose=True,
    )

    x_mlem_tv, metrics_mlem_tv = model_tv(
        y_blur, physics_blur, x_gt=x, compute_metrics=True
    )

Visualize the MLEM + TV results.

.. code-block:: Python

    psnr_mlem_tv = dinv.metric.PSNR()(x, x_mlem_tv)

    plot(
        {
            "Ground Truth": x,
            "Measurement": y_blur,
            # "MLEM": x_mlem,
            "MLEM with TV": x_mlem_tv,
        },
        subtitles=[
            "Reference",
            f"PSNR: {psnr_input.item():.2f} dB",
            # f"PSNR: {psnr_mlem.item():.2f} dB",
            f"PSNR: {psnr_mlem_tv.item():.2f} dB",
        ],
        figsize=(12, 3),
    )

    plot_curves(metrics_mlem_tv)

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_003.png
         :alt: Ground Truth, Measurement, MLEM with TV
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_004.png
         :alt: $\text{PSNR}(x_k)$, $F(x_k)$, Residual $\frac{||x_{k+1} - x_k||}{||x_k||}$
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_004.png
         :class: sphx-glr-multi-img

Computed Tomography with MLEM + TV + custom metrics
-------------------------------------------------------------

In emission tomography (PET/SPECT), the forward model is a Radon transform with
Poisson statistics. Here we take the simple Shepp-Logan phantom as ground truth
and use MLEM with a TV prior to reconstruct it from its noisy sinogram.

.. code-block:: Python

    # Load a grayscale image
    # (transform defined for reference; load_example below handles loading)
    val_transform_gray = transforms.Compose(
        [
            transforms.CenterCrop(img_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )

    x_ct = load_example(
        "SheppLogan.png",
        img_size=img_size,
        grayscale=True,
        resize_mode="resize",
        device=device,
    )

Set up the tomography physics: we define a parallel-beam tomography operator
with 120 projection angles uniformly spaced between 0° and 180°, and Poisson
noise.

.. code-block:: Python

    num_angles = 120
    gain_ct = 1 / 300

    physics_ct = dinv.physics.Tomography(
        img_width=img_size,
        angles=num_angles,
        device=device,
        noise_model=dinv.physics.PoissonNoise(
            gain=gain_ct, normalize=True, clip_positive=True
        ),
    )

    # Generate sinogram
    y_ct = physics_ct(x_ct)

    # Filtered back-projection as a simple baseline
    x_fbp = physics_ct.A_dagger(y_ct)
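The One-Step-Late update used by the TV-regularized models can also be sketched on a toy matrix problem. As a simplifying assumption, a smooth quadratic penalty :math:`\regname(x) = \|x\|^2/2` (gradient :math:`x`) stands in for TV so the sketch stays a few lines; all names and sizes are hypothetical and independent of the variables above.

```python
import torch

torch.manual_seed(0)

# Toy problem; noiseless data for illustration.
A = torch.rand(64, 32, dtype=torch.double)
x_true = torch.rand(32, dtype=torch.double) + 0.1
y = A @ x_true

lam = 1e-2  # regularization weight (hypothetical)
sensitivity = A.T @ torch.ones(64, dtype=torch.double)  # A^T 1

# One-Step-Late MLEM: the prior gradient, evaluated at the *previous*
# iterate, is added to the denominator of the multiplicative update.
x = torch.ones(32, dtype=torch.double)
for _ in range(200):
    grad_reg = x  # gradient of the quadratic penalty at x^k
    x = x / (sensitivity + lam * grad_reg) * (A.T @ (y / (A @ x)))
```

For a TV prior, `grad_reg` would be a (sub)gradient of TV instead; since the added term is nonnegative here, the denominator stays positive and the iterates stay nonnegative.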
Run MLEM + TV on the CT problem.

.. code-block:: Python

    data_fidelity_ct = dinv.optim.PoissonLikelihood(gain=gain_ct)
    prior_tv_ct = dinv.optim.prior.TVPrior(n_it_max=50)

    model_ct = dinv.optim.MLEM(
        data_fidelity=data_fidelity_ct,
        prior=prior_tv_ct,
        lambda_reg=1e-2,
        max_iter=50,
        early_stop=True,
        thres_conv=1e-6,
        crit_conv="residual",
        verbose=True,
    )

    x_ct_recon, metrics_ct = model_ct(y_ct, physics_ct, x_gt=x_ct, compute_metrics=True)

Visualize CT results and plot convergence curves.

.. code-block:: Python

    psnr_fbp = dinv.metric.PSNR()(x_ct, x_fbp)
    psnr_ct = dinv.metric.PSNR()(x_ct, x_ct_recon)
    ssim_fbp = dinv.metric.SSIM()(x_ct, x_fbp)
    ssim_ct = dinv.metric.SSIM()(x_ct, x_ct_recon)

    plot(
        {
            "Ground Truth": x_ct,
            "Sinogram": y_ct,
            "FBP": x_fbp,
            "MLEM with TV": x_ct_recon,
        },
        subtitles=[
            "Reference",
            "Measurements",
            f"PSNR: {psnr_fbp.item():.2f} dB\nSSIM: {ssim_fbp.item():.3f}",
            f"PSNR: {psnr_ct.item():.2f} dB\nSSIM: {ssim_ct.item():.3f}",
        ],
        figsize=(12, 4),
    )

    plot_curves(metrics_ct)

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_005.png
         :alt: Ground Truth, Sinogram, FBP, MLEM with TV
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_006.png
         :alt: $\text{PSNR}(x_k)$, $F(x_k)$, Residual $\frac{||x_{k+1} - x_k||}{||x_k||}$
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_poisson_mlem_006.png
         :class: sphx-glr-multi-img

:References:

.. footbibliography::

.. rst-class:: sphx-glr-timing

    **Total running time of the script:** (0 minutes 17.809 seconds)

.. _sphx_glr_download_auto_examples_optimization_demo_poisson_mlem.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: demo_poisson_mlem.ipynb `

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: demo_poisson_mlem.py `

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: demo_poisson_mlem.zip `

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_