.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_custom_optim.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics using the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_custom_optim.py:


Use iterative reconstruction algorithms
====================================================================================================

Follow this example to reconstruct images using an iterative algorithm.

The library provides a flexible framework to define your own iterative reconstruction algorithms,
which generally aim to solve an optimization problem of the form

.. math::
    \begin{equation}
    \label{eq:min_prob}
    \tag{1}
    \underset{x}{\arg\min} \quad \datafid{x}{y} + \lambda \reg{x},
    \end{equation}

where :math:`\datafid{x}{y}` is the data fidelity term, :math:`\reg{x}` is the (explicit or implicit)
regularization term, and :math:`\lambda` is a regularization parameter.

In this example, we demonstrate:

1. How to define your own iterative algorithm
2. How to package it as a :class:`reconstructor model `
3. How to build predefined algorithms with :class:`optim builder `

1. Defining your own iterative algorithm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 27-34

.. code-block:: Python

    import deepinv as dinv
    import torch

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

.. GENERATED FROM PYTHON SOURCE LINES 35-39

Define the physics of the problem
-----------------------------------
Here we define a simple inpainting problem, where we want to reconstruct an image from partial measurements.
We also load an image of a butterfly to use as ground truth.

.. GENERATED FROM PYTHON SOURCE LINES 39-50

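
Conceptually, inpainting multiplies the image pixel-wise by a binary mask :math:`m`, i.e. :math:`Ax = m \odot x`.
Such an operator is self-adjoint, and since :math:`A^{\top}A` is then a 0/1 projection, :math:`\|A\|^2 = 1`
as soon as at least one pixel is kept. The plain-PyTorch sketch below checks both facts on a toy operator.
It is an illustration under our own assumptions (a random Bernoulli mask keeping about half of the pixels,
and hypothetical helper names ``toy_mask``, ``A`` and ``A_adjoint``), not the implementation of
``deepinv.physics.Inpainting``.

.. code-block:: Python

    import torch

    torch.manual_seed(0)

    # Toy inpainting mask (an assumption for illustration): keep each pixel
    # of a (B, C, H, W) image independently with probability 0.5.
    toy_mask = (torch.rand(1, 3, 128, 128) < 0.5).float()


    def A(x: torch.Tensor) -> torch.Tensor:
        """Forward operator: keep the masked pixels, zero out the others."""
        return toy_mask * x


    def A_adjoint(y: torch.Tensor) -> torch.Tensor:
        """A 0/1 diagonal operator is self-adjoint: the adjoint is the same masking."""
        return toy_mask * y


    x_toy = torch.rand(1, 3, 128, 128)
    y_toy = A(x_toy)

    # Adjoint test: <A x, v> == <x, A^T v> for a random v.
    v = torch.randn_like(y_toy)
    lhs = torch.sum(A(x_toy) * v)
    rhs = torch.sum(x_toy * A_adjoint(v))
    print(torch.allclose(lhs, rhs))  # True

    # Power iteration on A^T A to estimate ||A||^2, its largest eigenvalue.
    z = torch.randn_like(x_toy)
    for _ in range(10):
        z = A_adjoint(A(z))
        z = z / z.norm()
    norm_A2_toy = torch.sum(z * A_adjoint(A(z))).item()  # Rayleigh quotient, close to 1.0

Because applying the mask twice is the same as applying it once, :math:`A^{\top}A = A` here, which is
why the power iteration settles immediately; for operators without a closed-form norm the same loop
simply needs more iterations.
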
.. code-block:: Python

    x = dinv.utils.load_example("butterfly.png", device=device, img_size=(128, 128))

    # Forward operator, here inpainting with a mask of 50% of the pixels
    physics = dinv.physics.Inpainting(img_size=(3, 128, 128), mask=0.5, device=device)

    # Generate measurements
    y = physics(x)

    dinv.utils.plot([x, y], titles=["Ground truth", "Measurements"])



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_001.png
   :alt: Ground truth, Measurements
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_001.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 51-59

Define the data fidelity term and prior
---------------------------------------

The library provides a set of :ref:`data fidelity ` terms and :ref:`priors ` that can be used
in the optimization problem. Here we use the :math:`\ell_2` data fidelity term and the Total Variation (TV) prior.

These classes provide all the methods needed by the optimization algorithm, such as the evaluation of the term,
its gradient, and its proximal operator.

.. GENERATED FROM PYTHON SOURCE LINES 59-64

.. code-block:: Python

    data_fidelity = dinv.optim.L2()  # Data fidelity term
    prior = dinv.optim.TVPrior()  # Prior term

.. GENERATED FROM PYTHON SOURCE LINES 65-79

Define the iterative algorithm
-----------------------------------
We will use the Proximal Gradient Descent (PGD) algorithm to solve the optimization problem defined above.
Its iterates are given by

.. math::

    x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname} \left( x_k - \gamma \nabla \datafidname(x_k, y) \right),

where :math:`\operatorname{prox}_{\gamma \lambda \regname}` is the proximal operator of the regularization term
scaled by :math:`\gamma \lambda`, :math:`\nabla \datafidname(x_k, y)` is the gradient of the data fidelity term,
:math:`\gamma` is the stepsize, and :math:`\lambda` is the regularization parameter.

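
For the :math:`\ell_2` data fidelity, assuming its default scaling
:math:`\datafidname(x, y) = \frac{1}{2}\|Ax - y\|_2^2` with :math:`A` the forward operator,
the gradient used in the update and its Lipschitz bound can be worked out directly:

.. math::

    \nabla \datafidname(x, y) = A^{\top} (A x - y),
    \qquad
    \|\nabla \datafidname(x, y) - \nabla \datafidname(x', y)\|_2
    = \|A^{\top} A (x - x')\|_2 \leq \|A\|^2 \, \|x - x'\|_2,

so the gradient of the data fidelity is Lipschitz continuous with constant :math:`L = \|A\|^2`,
the squared spectral norm of :math:`A`.
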
We can choose the stepsize as :math:`\gamma < \frac{2}{\|A\|^2}`, where :math:`A` is the forward operator and
:math:`\|A\|` its spectral norm; for a convex regularizer such as TV, this choice ensures convergence of the algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 79-101

.. code-block:: Python

    lambd = 0.05  # Regularization parameter

    # Compute the squared norm ||A||^2 of the operator A (power iteration)
    norm_A2 = physics.compute_norm(y, tol=1e-4, verbose=False).item()
    stepsize = 1.9 / norm_A2  # stepsize for the PGD algorithm, below 2 / ||A||^2

    # PGD algorithm
    max_iter = 20  # number of iterations
    x_k = torch.zeros_like(x, device=device)  # initial guess

    # To store the cost at each iteration:
    cost_history = torch.zeros(max_iter, device=device)

    with torch.no_grad():  # disable autodifferentiation
        for it in range(max_iter):
            u = x_k - stepsize * data_fidelity.grad(x_k, y, physics)  # Gradient step
            x_k = prior.prox(u, gamma=lambd * stepsize)  # Proximal step
            cost = data_fidelity(x_k, y, physics) + lambd * prior(x_k)  # Compute the cost
            cost_history[it] = cost  # Store the cost

.. GENERATED FROM PYTHON SOURCE LINES 102-103

Plot the cost history

.. GENERATED FROM PYTHON SOURCE LINES 103-113

.. code-block:: Python

    import matplotlib.pyplot as plt

    plt.figure(figsize=(8, 4))
    plt.plot(cost_history.detach().cpu().numpy(), marker="o")
    plt.title("Cost history")
    plt.xlabel("Iteration")
    plt.ylabel("Cost")
    plt.grid()
    plt.show()



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_002.png
   :alt: Cost history
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_002.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 114-115

Plot the results and metrics

.. GENERATED FROM PYTHON SOURCE LINES 115-127

.. code-block:: Python

    metric = dinv.metric.PSNR()

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Measurements\n {metric(y, x).item():.2f} dB": y,
            f"Recon w/ TV prior\n {metric(x_k, x).item():.2f} dB": x_k,
        }
    )


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_003.png
   :alt: Ground truth, Measurements 7.48 dB, Recon w/ TV prior 25.45 dB
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_003.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 128-139

Use a pretrained denoiser as prior
----------------------------------
We can improve the reconstruction by using a pretrained denoiser as the prior, i.e. by replacing the proximal step
with a denoising step. This is known as a plug-and-play (PnP) algorithm.
The library provides :ref:`a collection of classical and pretrained denoisers ` that can be used in iterative algorithms.

.. note::

    Plug-and-play algorithms can be sensitive to the choice of initialization.
    Here we use the TV estimate as the initial guess.

.. GENERATED FROM PYTHON SOURCE LINES 139-158

.. code-block:: Python

    x_k = x_k.clone()  # start from the TV reconstruction
    denoiser = dinv.models.DRUNet(device=device)  # Load a pretrained denoiser

    with torch.no_grad():  # disable autodifferentiation
        for _ in range(max_iter):
            u = x_k - stepsize * data_fidelity.grad(x_k, y, physics)  # Gradient step
            x_k = denoiser(u, sigma=0.05)  # replace prox by denoising step

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Measurements\n {metric(y, x).item():.2f} dB": y,
            f"Recon w/ PnP prior\n {metric(x_k, x).item():.2f} dB": x_k,
        }
    )



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_004.png
   :alt: Ground truth, Measurements 7.48 dB, Recon w/ PnP prior 30.25 dB
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_004.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 159-165

2. Package your algorithm as a Reconstructor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The iterative algorithm we defined above can be packaged as a :class:`Reconstructor `.
This allows you to :class:`test it ` on different physics and datasets, and to use it in a more flexible way,
including unfolding it and learning some of its parameters.

.. GENERATED FROM PYTHON SOURCE LINES 165-211

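
As a rough illustration of what unfolding means, the plain-PyTorch sketch below (not DeepInverse's API)
unrolls a few PGD iterations with the stepsize stored as a learnable ``torch.nn.Parameter``, so that the
gradient of a reconstruction loss reaches it. To stay short, it assumes a toy random-mask operator and swaps
the TV prior for a Tikhonov term :math:`\frac{\lambda}{2}\|x\|_2^2`, whose proximal operator has the closed form
:math:`u / (1 + \gamma \lambda)`. The class name ``UnrolledPGD`` is hypothetical.

.. code-block:: Python

    import torch
    import torch.nn as nn

    torch.manual_seed(0)


    class UnrolledPGD(nn.Module):
        """Hypothetical toy: PGD unrolled for a fixed number of iterations,
        with a learnable stepsize and a Tikhonov prior (lambd / 2) * ||x||^2."""

        def __init__(self, mask: torch.Tensor, lambd: float = 0.05, n_iter: int = 5):
            super().__init__()
            self.register_buffer("mask", mask)  # toy inpainting operator A x = mask * x
            self.lambd = lambd
            self.n_iter = n_iter
            # Learnable stepsize, initialised inside (0, 2 / ||A||^2) = (0, 2)
            self.stepsize = nn.Parameter(torch.tensor(1.0))

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            x_k = torch.zeros_like(y)  # initial guess
            # No torch.no_grad() here: autograd must track every iteration.
            for _ in range(self.n_iter):
                grad = self.mask * (self.mask * x_k - y)  # A^T (A x - y)
                u = x_k - self.stepsize * grad  # gradient step
                x_k = u / (1 + self.stepsize * self.lambd)  # closed-form Tikhonov prox
            return x_k


    toy_mask = (torch.rand(1, 1, 16, 16) < 0.5).float()
    x_true = torch.rand(1, 1, 16, 16)
    y_obs = toy_mask * x_true

    unrolled = UnrolledPGD(toy_mask)
    optimizer = torch.optim.Adam(unrolled.parameters(), lr=1e-2)

    # One supervised training step: the loss is backpropagated through all iterations.
    loss = torch.mean((unrolled(y_obs) - x_true) ** 2)
    optimizer.zero_grad()
    loss.backward()
    print(unrolled.stepsize.grad is not None)  # True
    optimizer.step()

Nothing in this sketch keeps the learned stepsize inside the convergent range :math:`(0, 2)`; a practical
unfolded model would typically constrain or reparametrize it, a design choice left out here for brevity.
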
.. code-block:: Python

    class MyPGD(dinv.models.Reconstructor):
        def __init__(self, data_fidelity, prior, stepsize, lambd, max_iter):
            super().__init__()
            self.data_fidelity = data_fidelity
            self.prior = prior
            self.stepsize = stepsize
            self.lambd = lambd
            self.max_iter = max_iter

        def forward(self, y, physics, **kwargs):
            """Algorithm forward pass.

            :param torch.Tensor y: measurements.
            :param dinv.physics.Physics physics: measurement operator.
            :return: (torch.Tensor) reconstructed image.
            """
            # Initial guess with the shape of the image (i.e. of A^T y, assuming a linear
            # operator), since y does not have the image shape for every physics.
            x_k = torch.zeros_like(physics.A_adjoint(y))

            # Disable autodifferentiation, remove this if you want to unfold
            with torch.no_grad():
                for _ in range(self.max_iter):
                    u = x_k - self.stepsize * self.data_fidelity.grad(
                        x_k, y, physics
                    )  # Gradient step
                    x_k = self.prior.prox(
                        u, gamma=self.lambd * self.stepsize
                    )  # Proximal step
            return x_k


    tv_algo = MyPGD(data_fidelity, prior, stepsize, lambd, max_iter)

    # Standard reconstructor forward pass
    x_hat = tv_algo(y, physics)

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Measurements\n {metric(y, x).item():.2f} dB": y,
            f"Recon w/ custom PGD\n {metric(x_hat, x).item():.2f} dB": x_hat,
        }
    )



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_005.png
   :alt: Ground truth, Measurements 7.48 dB, Recon w/ custom PGD 25.45 dB
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_005.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 212-221

3. Using a predefined optimization algorithm with ``optim_builder``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The library also lets you define :ref:`standard optimization algorithms ` as :class:`Reconstructors `
in one line of code using the :func:`deepinv.optim.optim_builder` function.
For example, the above PnP algorithm can be defined as follows:

.. seealso::

    See :ref:`the optimization examples ` for more examples of using ``optim_builder``.

.. GENERATED FROM PYTHON SOURCE LINES 221-256

.. code-block:: Python

    prior = dinv.optim.PnP(denoiser=denoiser)  # prior with prox via denoising step


    def custom_init(y: torch.Tensor, physics: dinv.physics.Physics) -> dict:
        """
        Custom initialization function for the optimization algorithm.
        The function should return a dictionary with the key "est" containing a tuple
        with the initial guess (the TV solution in this case)
        and the dual variables (None in this case).
        """
        primal = tv_algo(y, physics)
        dual = None  # No dual variables in this case
        return {"est": (primal, dual)}


    model = dinv.optim.optim_builder(
        iteration="PGD",
        prior=prior,
        data_fidelity=data_fidelity,
        params_algo={"stepsize": stepsize, "g_param": 0.05},
        max_iter=max_iter,
        custom_init=custom_init,
    )

    x_hat = model(y, physics)

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Measurements\n {metric(y, x).item():.2f} dB": y,
            f"Reconstruction\n {metric(x_hat, x).item():.2f} dB": x_hat,
        }
    )



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_optim_006.png
   :alt: Ground truth, Measurements 7.48 dB, Reconstruction 30.25 dB
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_optim_006.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 257-266

🎉 Well done, you now know how to define your own iterative reconstruction algorithm!

What's next?
~~~~~~~~~~~~
* Learn more about optimization algorithms in the :ref:`optimization user guide `.
* Check out diffusion and MCMC iterative algorithms in the :ref:`sampling user guide `.
* Check out more :ref:`iterative algorithms examples `.
* Learn how to run the algorithm on a whole dataset by following the :ref:`bring your own dataset ` tutorial.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 27.407 seconds)


.. _sphx_glr_download_auto_examples_basics_demo_custom_optim.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_custom_optim.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_custom_optim.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_custom_optim.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_