.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plug-and-play/demo_PnP_mirror_descent.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plug-and-play_demo_PnP_mirror_descent.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plug-and-play_demo_PnP_mirror_descent.py:


Plug-and-Play algorithm with Mirror Descent for Poisson noise inverse problems.
====================================================================================================

This is a simple example to show how to use a mirror descent algorithm for solving an inverse problem with Poisson noise.

Mirror descent with a RED denoiser reads

.. math::

    x_{k+1} = \nabla \phi^* \left( \nabla \phi(x_k) - \tau \nabla \distance{A(x_k)}{y} - \tau ( x_k - D_\sigma(x_k)) \right)

where :math:`\phi` is a convex Bregman potential (of Legendre type) with convex conjugate :math:`\phi^*`, so that :math:`\nabla \phi^* = (\nabla \phi)^{-1}`, :math:`\distance{A(x)}{y}` is the data-fidelity term, and :math:`D_\sigma` is a denoiser with noise level :math:`\sigma`.
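
To make the update concrete, here is a minimal plain-Python sketch of one mirror descent step with a RED term, acting elementwise on lists. It is not deepinv's implementation: the mirror maps :math:`\nabla \phi` and :math:`\nabla \phi^*`, the data-fidelity gradient and the denoiser are generic callables, and the toy quadratic data term and the mean-value "denoiser" are hypothetical placeholders. With the Euclidean potential :math:`\phi(x) = \frac{1}{2}\|x\|^2`, both mirror maps are the identity and the step reduces to a plain gradient step.

.. code-block:: Python

    # Minimal sketch (not deepinv's implementation) of one mirror descent step
    # with a RED prior; vectors are plain Python lists and maps act elementwise.


    def md_red_step(x, grad_f, denoiser, grad_phi, grad_phi_star, tau):
        g = grad_f(x)  # gradient of the data-fidelity term at x_k
        d = denoiser(x)  # D_sigma(x_k)
        # map to the dual space, take the step there, then map back
        u = [grad_phi(xi) - tau * (gi + (xi - di)) for xi, gi, di in zip(x, g, d)]
        return [grad_phi_star(ui) for ui in u]


    def identity(t):
        # Euclidean potential phi(x) = ||x||^2 / 2: both mirror maps are the identity
        return t


    # Hypothetical toy problem: f(x) = ||x - y||^2 / 2, so grad f(x) = x - y.
    y = [1.0, 2.0, 3.0]


    def grad_f(x):
        return [xi - yi for xi, yi in zip(x, y)]


    def mean_denoiser(x):
        # hypothetical "denoiser" replacing every entry by the mean
        return [sum(x) / len(x)] * len(x)


    x0 = [0.0, 0.0, 0.0]
    x1 = md_red_step(x0, grad_f, mean_denoiser, identity, identity, tau=0.5)
    print(x1)  # [0.5, 1.0, 1.5]

Passing non-identity mirror maps, such as those of Burg's entropy, is what makes the method genuinely non-Euclidean.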

In this example, we use the DnCNN denoiser. As the observation has been corrupted with Poisson noise, we use the :class:`deepinv.optim.PoissonLikelihood` data-fidelity term.
As shown in https://publications.ut-capitole.fr/id/eprint/25852/1/25852.pdf, a Bregman potential well suited to this data-fidelity term is Burg's entropy :class:`deepinv.optim.bregman.BurgEntropy`.
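
For Burg's entropy :math:`\phi(x) = -\sum_i \log x_i` (defined for :math:`x > 0`), the mirror maps have closed forms: :math:`\nabla \phi(x) = -1/x` and, for :math:`u < 0`, :math:`\nabla \phi^*(u) = -1/u`, both elementwise. A mirror step with gradient :math:`g_k` therefore becomes :math:`x_{k+1} = 1 / (1/x_k + \tau g_k)`, which keeps the iterate positive as long as the denominator stays positive. The sketch below assumes an identity forward operator and the unscaled Poisson negative log-likelihood :math:`f(x) = \sum_i x_i - y_i \log x_i` (deepinv's ``PoissonLikelihood`` also involves a gain, omitted here); it is an illustration, not the library code.

.. code-block:: Python

    def grad_phi_burg(x):
        # gradient of Burg's entropy phi(x) = -sum(log x_i), for x_i > 0
        return [-1.0 / xi for xi in x]


    def grad_phi_star_burg(u):
        # gradient of the convex conjugate, valid for u_i < 0; inverts grad_phi_burg
        return [-1.0 / ui for ui in u]


    def burg_md_step(x, g, tau):
        """One mirror step x_{k+1} = grad phi*(grad phi(x_k) - tau * g)."""
        u = [ui - tau * gi for ui, gi in zip(grad_phi_burg(x), g)]
        if any(ui >= 0 for ui in u):
            raise ValueError("stepsize too large: the iterate would leave x > 0")
        return grad_phi_star_burg(u)


    # Hypothetical Poisson toy problem with A = identity:
    # f(x) = sum(x_i - y_i * log x_i), hence grad f(x) = 1 - y / x.
    y = [2.0, 0.5]


    def grad_poisson(x):
        return [1.0 - yi / xi for xi, yi in zip(x, y)]


    x = [1.0, 1.0]
    x = burg_md_step(x, grad_poisson(x), tau=0.5)
    print(x)  # [2.0, 0.8]

    for _ in range(100):
        x = burg_md_step(x, grad_poisson(x), tau=0.5)
    print(x)  # close to y = [2.0, 0.5], the minimiser of f

The positivity constraint is never enforced by projection: it is built into the geometry of the potential, which is one reason this choice pairs naturally with a Poisson likelihood whose domain is :math:`x > 0`.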

.. GENERATED FROM PYTHON SOURCE LINES 18-30

.. code-block:: Python


    import deepinv as dinv
    from pathlib import Path
    import torch
    from torch.utils.data import DataLoader
    from deepinv.optim.data_fidelity import PoissonLikelihood
    from deepinv.optim.prior import RED
    from deepinv.optim import optim_builder
    from deepinv.optim.bregman import BurgEntropy
    from deepinv.utils.demo import load_url_image, get_image_url
    from deepinv.utils.plotting import plot, plot_curves








.. GENERATED FROM PYTHON SOURCE LINES 31-34

Set up paths for data loading and results.
----------------------------------------------------------------------------------------


.. GENERATED FROM PYTHON SOURCE LINES 34-65

.. code-block:: Python


    BASE_DIR = Path(".")
    ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set the global random seed from pytorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    img_size = 64
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
    url = get_image_url("butterfly.png")
    x_true = load_url_image(url=url, img_size=img_size).to(device)
    x = x_true.clone()


    n_channels = 3  # 3 for color images, 1 for gray-scale images
    operation = "deblurring"

    # Degradation parameters
    noise_level_img = 1 / 40  # Poisson noise gain

    # Generate the Gaussian blur operator with Poisson noise.
    physics = dinv.physics.BlurFFT(
        img_size=(n_channels, img_size, img_size),
        filter=dinv.physics.blur.gaussian_blur(),
        device=device,
        noise_model=dinv.physics.PoissonNoise(gain=noise_level_img),
    )
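
The ``gain`` parameter sets the noise level. A common gain-scaled Poisson model, which we assume matches ``dinv.physics.PoissonNoise`` (an assumption about the library, not taken from this example), is :math:`y = \gamma z` with :math:`z \sim \mathcal{P}(x / \gamma)`, so that :math:`\mathbb{E}[y] = x` and :math:`\mathrm{Var}[y] = \gamma x`. With :math:`\gamma = 1/40`, a pixel of intensity 0.5 thus has variance 0.0125. The standard-library sketch below checks this empirically, drawing Poisson samples with Knuth's algorithm.

.. code-block:: Python

    import math
    import random


    def poisson_sample(lam, rng):
        """Draw one Poisson(lam) sample with Knuth's algorithm (fine for small lam)."""
        threshold = math.exp(-lam)
        k, p = 0, 1.0
        while True:
            p *= rng.random()
            if p <= threshold:
                return k
            k += 1


    def gain_poisson_noise(x, gain, rng):
        # assumed model: y = gain * Poisson(x / gain), so E[y] = x and Var[y] = gain * x
        return [gain * poisson_sample(xi / gain, rng) for xi in x]


    rng = random.Random(0)
    gain = 1 / 40
    clean = [0.5] * 20000  # constant toy "image" of intensity 0.5
    noisy = gain_poisson_noise(clean, gain, rng)
    mean = sum(noisy) / len(noisy)
    var = sum((v - mean) ** 2 for v in noisy) / len(noisy)
    print(f"mean {mean:.3f} (expected 0.5), variance {var:.4f} (expected {gain * 0.5:.4f})")

A smaller gain means more photons per unit intensity, hence a lower relative noise level.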








.. GENERATED FROM PYTHON SOURCE LINES 66-69

Define the PnP algorithm.
----------------------------------------------------------------------------------------
The chosen algorithm here is MD (mirror descent).

.. GENERATED FROM PYTHON SOURCE LINES 69-102

.. code-block:: Python



    # Select the data fidelity term, here Poisson likelihood due to the use of Poisson noise in the forward operator.
    data_fidelity = PoissonLikelihood(gain=noise_level_img)

    # Set up the denoising prior. Note that we use a denoiser trained for Gaussian noise, even though the observation noise is Poisson.
    prior = RED(denoiser=dinv.models.DnCNN(depth=20, pretrained="download").to(device))

    # Set up the optimization parameters
    max_iter = 200  # number of iterations
    stepsize = 1.0  # stepsize of the algorithm
    sigma_denoiser = 0.05  # noise level parameter of the Gaussian denoiser
    # Wrap all the restoration parameters in a 'params_algo' dictionary. The Bregman
    # potential is not set here: it is passed to optim_builder via bregman_potential.
    params_algo = {
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
    }

    # Logging parameters
    verbose = True

    # Build the PnP mirror descent model, with Burg's entropy as Bregman potential.
    model = optim_builder(
        iteration="MD",
        prior=prior,
        data_fidelity=data_fidelity,
        early_stop=True,
        max_iter=max_iter,
        verbose=verbose,
        params_algo=params_algo,
        bregman_potential=BurgEntropy(),
    )
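
The iterations assembled by ``optim_builder`` can be pictured with the following toy analogue. Everything in it is a hypothetical simplification rather than deepinv's code: the forward operator is the identity, the data term is the unscaled Poisson negative log-likelihood, a 3-tap moving average stands in for DnCNN, and the loop stops once the relative change :math:`\|x_{k+1} - x_k\| / \|x_k\|` falls below a threshold, mimicking the role of ``early_stop`` and ``max_iter``.

.. code-block:: Python

    import math


    def moving_average(x):
        """Hypothetical stand-in denoiser: 3-tap moving average, edges replicated."""
        n = len(x)
        return [(x[max(i - 1, 0)] + x[i] + x[min(i + 1, n - 1)]) / 3 for i in range(n)]


    def pnp_mirror_descent(y, denoiser, stepsize=0.1, max_iter=200, thres_conv=1e-8):
        x = list(y)  # initialise with the (strictly positive) observation
        for it in range(1, max_iter + 1):
            d = denoiser(x)
            # gradient of f(x) = sum(x - y * log x) plus the RED term x - D(x)
            g = [(1.0 - yi / xi) + (xi - di) for xi, yi, di in zip(x, y, d)]
            # Burg's entropy mirror step: x_new = 1 / (1 / x + stepsize * g)
            denom = [1.0 / xi + stepsize * gi for xi, gi in zip(x, g)]
            if min(denom) <= 0:
                raise ValueError("stepsize too large: the iterate would leave x > 0")
            x_new = [1.0 / v for v in denom]
            # relative change ||x_new - x|| / ||x||, used as the stopping criterion
            residual = math.dist(x_new, x) / math.hypot(*x)
            x = x_new
            if residual < thres_conv:
                break
        return x, it


    y = [1.0, 1.5, 2.0, 2.5, 2.0, 1.5]  # toy, strictly positive 1-D "observation"
    x_hat, n_iter = pnp_mirror_descent(y, moving_average)
    print(n_iter, [round(v, 3) for v in x_hat])

At a fixed point, :math:`1 - y/x + x - D(x) = 0` holds elementwise: the Poisson gradient is balanced by the RED term.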









.. GENERATED FROM PYTHON SOURCE LINES 103-108

Evaluate the model on the problem and plot the results.
--------------------------------------------------------------------

With ``compute_metrics=True``, the model returns the reconstruction together with the metrics computed along the iterations.
To compute the PSNR, the ground-truth image must be passed as ``x_gt``.
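
The PSNR reported below is :math:`10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})`. A minimal sketch on flat lists of pixels, assuming images scaled to :math:`[0, 1]` so that :math:`\mathrm{MAX} = 1`:

.. code-block:: Python

    import math


    def psnr(x_hat, x_ref, max_pixel=1.0):
        """PSNR in dB, 10 * log10(max_pixel**2 / MSE), for flat lists of pixels."""
        mse = sum((a - b) ** 2 for a, b in zip(x_hat, x_ref)) / len(x_ref)
        return 10 * math.log10(max_pixel**2 / mse)


    x_ref = [0.0, 0.5, 1.0, 0.25]
    x_hat = [0.1, 0.4, 0.9, 0.35]  # every pixel is off by 0.1, so MSE = 0.01
    print(f"{psnr(x_hat, x_ref):.2f} dB")  # 20.00 dB

Each additional 10 dB corresponds to a tenfold reduction of the MSE.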

.. GENERATED FROM PYTHON SOURCE LINES 108-133

.. code-block:: Python


    y = physics(x)
    x_lin = physics.A_adjoint(y)

    # run the model on the problem.
    with torch.no_grad():
        x_model, metrics = model(
            y, physics, x_gt=x, compute_metrics=True
        )  # reconstruction with PnP algorithm

    # compute PSNR
    print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
    print(f"PnP reconstruction PSNR: {dinv.metric.PSNR()(x, x_model).item():.2f} dB")

    # plot images. Images are saved in RESULTS_DIR.
    imgs = [y, x, x_lin, x_model]
    plot(
        imgs,
        titles=["Input", "GT", "Linear", "Recons."],
        save_dir=RESULTS_DIR / "images",
        show=True,
    )

    # plot convergence curves. Metrics are saved in RESULTS_DIR.
    plot_curves(metrics, save_dir=RESULTS_DIR / "curves", show=True)



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/plug-and-play/images/sphx_glr_demo_PnP_mirror_descent_001.png
         :alt: Input, GT, Linear, Recons.
         :srcset: /auto_examples/plug-and-play/images/sphx_glr_demo_PnP_mirror_descent_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/plug-and-play/images/sphx_glr_demo_PnP_mirror_descent_002.png
         :alt: PSNR, residual
         :srcset: /auto_examples/plug-and-play/images/sphx_glr_demo_PnP_mirror_descent_002.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Linear reconstruction PSNR: 20.97 dB
    PnP reconstruction PSNR: 23.72 dB





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 7.907 seconds)


.. _sphx_glr_download_auto_examples_plug-and-play_demo_PnP_mirror_descent.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_PnP_mirror_descent.ipynb <demo_PnP_mirror_descent.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_PnP_mirror_descent.py <demo_PnP_mirror_descent.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_PnP_mirror_descent.zip <demo_PnP_mirror_descent.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_