.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plug-and-play/demo_PnP_custom_optim.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plug-and-play_demo_PnP_custom_optim.py:


PnP with custom optimization algorithm (Primal-Dual Condat-Vu)
====================================================================================================

This example shows how to define your own optimization algorithm.
For example, here we implement the Primal-Dual Condat-Vu (CV) algorithm
and apply it to Single-Pixel Camera reconstruction.

.. GENERATED FROM PYTHON SOURCE LINES 9-21

.. code-block:: Python

    import deepinv as dinv
    from pathlib import Path
    import torch
    from deepinv.models import DnCNN
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.prior import PnP
    from deepinv.optim.optimizers import BaseOptim
    from deepinv.utils import load_example
    from deepinv.utils.plotting import plot, plot_curves
    from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep


.. GENERATED FROM PYTHON SOURCE LINES 22-38

Define a custom optimization algorithm
----------------------------------------------------------------------------------------
Creating your own optimization algorithm only requires the definition of an iteration step.
The iterator should be a subclass of :class:`deepinv.optim.OptimIterator`.

The Condat-Vu Primal-Dual algorithm is defined as follows:

.. math::
    \begin{align*}
    v_k &= x_k - \tau A^\top z_k \\
    x_{k+1} &= \operatorname{prox}_{\tau g}(v_k) \\
    u_k &= z_k + \sigma A(2x_{k+1}-x_k) \\
    z_{k+1} &= \operatorname{prox}_{\sigma f^*}(u_k)
    \end{align*}

where :math:`f^*` is the Fenchel-Legendre conjugate of :math:`f`.

.. GENERATED FROM PYTHON SOURCE LINES 38-80

.. code-block:: Python

    class CVIteration(OptimIterator):
        r"""
        Single iteration of the Condat-Vu Primal-Dual algorithm.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepCV(**kwargs)
            self.f_step = fStepCV(**kwargs)

        def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
            r"""
            Single iteration of the Condat-Vu algorithm.

            :param dict X: Dictionary containing the current iterate and the estimated cost.
            :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
            :param dict cur_prior: dictionary containing the prior-related term of interest, e.g. its proximal operator or gradient.
            :param dict cur_params: dictionary containing the current parameters of the model.
            :param torch.Tensor y: Input data.
            :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
            :return: Dictionary `{"est": (x, z), "cost": F}` containing the updated current iterate and the estimated current cost.
            """
            x_prev, z_prev = X["est"]

            # Primal step: v_k = x_k - tau * A^T z_k, then x_{k+1} = prox_{tau g}(v_k)
            v = x_prev - cur_params["stepsize"] * physics.A_adjoint(z_prev)
            x = self.g_step(v, cur_prior, cur_params)

            # Dual step: u_k = z_k + sigma * A(2 x_{k+1} - x_k), then z_{k+1} = prox_{sigma f*}(u_k)
            u = z_prev + cur_params["stepsize_dual"] * physics.A(2 * x - x_prev)
            z = self.f_step(u, cur_data_fidelity, cur_params, y, physics)

            F = (
                self.F_fn(x, cur_data_fidelity, cur_params, y, physics)
                if self.has_cost
                and self.F_fn is not None
                and cur_data_fidelity is not None
                and cur_prior is not None
                else None
            )
            return {"est": (x, z), "cost": F}

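Note that convergence of the Condat-Vu iteration typically requires the primal and dual stepsizes to satisfy
:math:`\tau \sigma \|A\|^2 \leq 1`. The snippet below is a minimal sketch of how this condition could be checked
for the single-pixel operator used later in this example; it assumes that the physics exposes a power-iteration
estimate of :math:`\|A^\top A\| = \|A\|^2` through ``compute_norm`` (an assumption of this sketch, not part of the
original example).

.. code-block:: Python

    import torch
    import deepinv as dinv

    # The single-pixel camera used later in this example (600 Hadamard measurements).
    physics = dinv.physics.SinglePixelCamera(
        m=600, img_size=(1, 64, 64), ordering="cake_cutting"
    )

    x0 = torch.randn(1, 1, 64, 64)  # random image to initialize the power iteration
    # `compute_norm` is assumed to return an estimate of ||A^T A||, i.e. ||A||^2.
    A_norm2 = float(physics.compute_norm(x0))

    tau, sigma = 0.99, 0.99  # primal and dual stepsizes used later in this example
    print(f"tau * sigma * ||A||^2 = {tau * sigma * A_norm2:.3f} (should be <= 1)")
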
.. GENERATED FROM PYTHON SOURCE LINES 81-100

Define the custom fStep and gStep modules
----------------------------------------------------------------------------------------
The iterator relies on custom fStepCV (subclass of :class:`deepinv.optim.optim_iterators.fStep`)
and gStepCV (subclass of :class:`deepinv.optim.optim_iterators.gStep`) modules.

In this case, the fStep module is defined as follows:

.. math::
    z_{k+1} = \operatorname{prox}_{\sigma f^*}(u_k)

where :math:`f^*` is the Fenchel-Legendre conjugate of :math:`f`. The proximal operator of :math:`f^*` is computed
from the proximal operator of :math:`f` via Moreau's identity
:math:`\operatorname{prox}_{\sigma f^*}(u) = u - \sigma \operatorname{prox}_{f/\sigma}(u/\sigma)`,
and the gStep module is a simple proximal step on the prior term :math:`\lambda g`:

.. math::
    x_{k+1} = \operatorname{prox}_{\tau \lambda g}(v_k)

.. GENERATED FROM PYTHON SOURCE LINES 100-149

.. code-block:: Python

    class fStepCV(fStep):
        r"""
        Condat-Vu fStep module to compute :math:`\operatorname{prox}_{\sigma f^*}(u_k)`.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, u, cur_data_fidelity, cur_params, y, physics):
            r"""
            Single iteration step on the data-fidelity term :math:`f`.

            :param torch.Tensor u: Current dual iterate :math:`u_k = z_k + \sigma A(2x_{k+1}-x_k)`.
            :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
            :param dict cur_params: Dictionary containing the current fStep parameters (key `"stepsize_dual"`).
            :param torch.Tensor y: Input data.
            :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
            """
            # Proximal step on f*, computed from the prox of f via Moreau's identity.
            return cur_data_fidelity.d.prox_conjugate(
                u, y, gamma=cur_params["stepsize_dual"]
            )


    class gStepCV(gStep):
        r"""
        Condat-Vu gStep module to compute :math:`\operatorname{prox}_{\tau \lambda g}(v_k)`.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, v, cur_prior, cur_params):
            r"""
            Single iteration step on the prior term :math:`\lambda g`.

            :param torch.Tensor v: Current primal iterate :math:`v_k = x_k - \tau A^\top z_k`.
            :param dict cur_prior: Dictionary containing the current prior.
            :param dict cur_params: Dictionary containing the current gStep parameters (keys `"stepsize"`, `"g_param"` and `"lambda"`).
            """
            # Proximal step on lambda * g with primal stepsize tau = cur_params["stepsize"].
            return cur_prior.prox(
                v,
                cur_params["g_param"],
                gamma=cur_params["lambda"] * cur_params["stepsize"],
            )


.. GENERATED FROM PYTHON SOURCE LINES 150-153

Define the Condat-Vu model as a subclass of :class:`deepinv.optim.BaseOptim`, following the model of the other optimizers of the library (see for example :class:`deepinv.optim.ADMM`).
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 153-188

.. code-block:: Python

    class CV(BaseOptim):
        r"""
        Primal-Dual Condat-Vu (CV) optimization algorithm.
        """

        def __init__(
            self,
            data_fidelity=None,
            prior=None,
            lambda_reg=1.0,
            stepsize=1.0,
            stepsize_dual=1.0,
            beta=1.0,
            sigma_denoiser=None,
            **kwargs,
        ):
            params_algo = {
                "lambda": lambda_reg,
                "stepsize": stepsize,
                "stepsize_dual": stepsize_dual,
                "g_param": sigma_denoiser,
                "beta": beta,
            }
            super(CV, self).__init__(
                CVIteration(),
                params_algo=params_algo,
                data_fidelity=data_fidelity,
                prior=prior,
                **kwargs,
            )


.. GENERATED FROM PYTHON SOURCE LINES 189-192

Setup paths for data loading and results.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 192-197

.. code-block:: Python

    BASE_DIR = Path(".")
    RESULTS_DIR = BASE_DIR / "results"


.. GENERATED FROM PYTHON SOURCE LINES 198-200

Load base image datasets and degradation operators.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 200-222

.. code-block:: Python

    # Set the global random seed from pytorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    # Set up the variable to fetch dataset and operators.
    method = "PnP"
    dataset_name = "set3c"
    img_size = 64

    x = load_example(
        "barbara.jpeg",
        img_size=img_size,
        grayscale=True,
        resize_mode="resize",
        device=device,
    )

    operation = "single_pixel"


.. GENERATED FROM PYTHON SOURCE LINES 223-228

Set the forward operator
--------------------------------------------------------------------------------
We use the :class:`deepinv.physics.SinglePixelCamera` class from the physics module to generate single-pixel measurements.
The forward operator consists of a multiplication with the low frequencies of the Hadamard transform.

.. GENERATED FROM PYTHON SOURCE LINES 228-244

.. code-block:: Python

    noise_level_img = 0.03  # Gaussian Noise standard deviation for the degradation
    n_channels = 1  # 3 for color images, 1 for gray-scale images

    physics = dinv.physics.SinglePixelCamera(
        m=600,
        img_size=(1, 64, 64),
        ordering="cake_cutting",
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
        device=device,
    )

    # Use parallel dataloader if using a GPU to speed up training,
    # otherwise, as all computations run on the CPU, use synchronous data loading.
    num_workers = 4 if torch.cuda.is_available() else 0


.. GENERATED FROM PYTHON SOURCE LINES 245-253

Set up the PnP algorithm to solve the inverse problem.
--------------------------------------------------------------------------------
We build the PnP model using our custom ``CV`` class defined above.
The primal stepsize :math:`\tau` corresponds to the ``stepsize`` key and the dual stepsize :math:`\sigma` to the ``stepsize_dual`` key.

For the denoiser, we choose the 1-Lipschitz grayscale DnCNN model (see the :ref:`pretrained-weights `).

.. GENERATED FROM PYTHON SOURCE LINES 253-285

.. code-block:: Python

    # Set up the PnP algorithm parameters:
    stepsize = 0.99  # primal stepsize
    stepsize_dual = 0.99  # dual stepsize
    sigma_denoiser = 0.01  # denoiser parameter (noise level)
    max_iter = 200
    early_stop = True  # stop the algorithm when convergence is reached

    # Select the data fidelity term
    data_fidelity = L2()

    # Specify the denoising prior
    denoiser = DnCNN(
        in_channels=n_channels,
        out_channels=n_channels,
        pretrained="download_lipschitz",
        device=device,
    )
    prior = PnP(denoiser=denoiser)

    # instantiate the algorithm class to solve the IP problem.
    model = CV(
        prior=prior,
        data_fidelity=data_fidelity,
        stepsize=stepsize,
        stepsize_dual=stepsize_dual,
        sigma_denoiser=sigma_denoiser,
        early_stop=early_stop,
        max_iter=max_iter,
        verbose=True,
    )


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://huggingface.co/deepinv/dncnn/resolve/main/dncnn_sigma2_lipschitz_gray.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/dncnn_sigma2_lipschitz_gray.pth


.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: demo_PnP_custom_optim.py `

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: demo_PnP_custom_optim.zip `

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_