.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plug-and-play/demo_vanilla_PnP.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plug-and-play_demo_vanilla_PnP.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plug-and-play_demo_vanilla_PnP.py:


Vanilla PnP for computed tomography (CT).
====================================================================================================

This example shows how to use a standard PnP algorithm with a DnCNN denoiser for computed tomography.

.. GENERATED FROM PYTHON SOURCE LINES 8-19

.. code-block:: Python

    import deepinv as dinv
    from pathlib import Path
    import torch
    from deepinv.models import DnCNN
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.prior import PnP
    from deepinv.optim.optimizers import optim_builder
    from deepinv.utils.demo import load_url_image, get_image_url
    from deepinv.utils.plotting import plot, plot_curves

.. GENERATED FROM PYTHON SOURCE LINES 20-23

Set up paths for data loading and results.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 23-26

.. code-block:: Python

    BASE_DIR = Path(".")
    RESULTS_DIR = BASE_DIR / "results"

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Load image and parameters
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 29-43

.. code-block:: Python

    # Set the global PyTorch random seed to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    # Set up the variables used to fetch the dataset and operators.
    method = "PnP"
    img_size = 32
    url = get_image_url("SheppLogan.png")
    x = load_url_image(
        url=url, img_size=img_size, grayscale=True, resize_mode="resize", device=device
    )
    operation = "tomography"

.. GENERATED FROM PYTHON SOURCE LINES 44-48

Set the forward operator
--------------------------------------------------------------------------------
We use the :class:`deepinv.physics.Tomography` class from the physics module to generate CT measurements.

.. GENERATED FROM PYTHON SOURCE LINES 48-69

.. code-block:: Python

    noise_level_img = 0.03  # Gaussian Noise standard deviation for the degradation
    angles = 100
    n_channels = 1  # 3 for color images, 1 for gray-scale images
    physics = dinv.physics.Tomography(
        img_width=img_size,
        angles=angles,
        circle=False,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
    )

    PI = 4 * torch.ones(1).atan()
    SCALING = (PI / (2 * angles)).to(device)  # approximate operator norm of A^T A

    # Use parallel data loading when a GPU is available;
    # otherwise, since all computation runs on the CPU, use synchronous data loading.
    num_workers = 4 if torch.cuda.is_available() else 0

.. GENERATED FROM PYTHON SOURCE LINES 70-80

Set up the PnP algorithm to solve the inverse problem.
--------------------------------------------------------------------------------
We use the Proximal Gradient Descent optimization algorithm.
The algorithm alternates between a denoising step and a gradient descent step.
The denoising step is performed by a pretrained DnCNN denoiser :class:`deepinv.models.DnCNN`.

Set up the PnP algorithm parameters: the ``stepsize`` and ``g_param``, the noise level of the denoiser.

Attention: The choice of the stepsize is crucial as it also defines the amount of regularization. Indeed, the regularization parameter ``lambda`` is implicitly defined by the stepsize.
Both the stepsize and the noise level of the denoiser control the regularization strength and should be tuned to the specific problem.
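
As a rough illustration of the alternation described above, the following pure-Python sketch
applies a gradient descent step on the data-fidelity term and then a denoising step, i.e.
``x_new = D(x - stepsize * A^T(A x - y))``. It does not use deepinv: the diagonal toy operator,
the neighbour-averaging ``toy_denoiser`` (a stand-in for DnCNN) and all function names are
assumptions made for this illustration only.

.. code-block:: Python

    # Illustrative PnP proximal gradient descent on a toy 1D problem.
    # All names below are hypothetical and not part of the deepinv API.


    def forward(x, weights):
        """Toy diagonal forward operator A: scales each entry by a fixed weight."""
        return [w * xi for w, xi in zip(weights, x)]


    def adjoint(y, weights):
        """Adjoint A^T of the diagonal operator (equal to A, since A is symmetric)."""
        return [w * yi for w, yi in zip(weights, y)]


    def toy_denoiser(x, strength):
        """Stand-in for a learned denoiser: blend each sample with its neighbour average."""
        n = len(x)
        out = []
        for i in range(n):
            left = x[max(i - 1, 0)]  # replicate the border sample at the edges
            right = x[min(i + 1, n - 1)]
            out.append((1 - strength) * x[i] + strength * 0.5 * (left + right))
        return out


    def pnp_pgd(y, weights, stepsize, strength, max_iter=200, tol=1e-8):
        """Alternate a gradient step on 0.5 * ||A x - y||^2 with a denoising step."""
        x = adjoint(y, weights)  # initialise with the adjoint of the measurements
        for _ in range(max_iter):
            residual = [a - b for a, b in zip(forward(x, weights), y)]
            grad = adjoint(residual, weights)
            u = [xi - stepsize * gi for xi, gi in zip(x, grad)]  # gradient step
            x_new = toy_denoiser(u, strength)  # denoising step replaces the prox
            change = sum((a - b) ** 2 for a, b in zip(x_new, x))
            x = x_new
            if change < tol:  # stop early once the iterates stabilise
                break
        return x


    # Usage: recover a constant signal from noise-free toy measurements.
    weights = [0.5] * 8
    y = forward([1.0] * 8, weights)
    x_hat = pnp_pgd(y, weights, stepsize=1.0, strength=0.3)
    print([round(v, 3) for v in x_hat])

In this sketch, ``stepsize`` and ``strength`` play roles loosely analogous to ``stepsize`` and
``g_param``: both change how strongly each iteration smooths the estimate.
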
The following parameters have been chosen manually.

.. GENERATED FROM PYTHON SOURCE LINES 80-119

.. code-block:: Python

    params_algo = {"stepsize": 0.01 * SCALING, "g_param": noise_level_img}
    max_iter = 100
    early_stop = True

    # Select the data fidelity term
    data_fidelity = L2()

    # Specify the denoising prior
    denoiser = DnCNN(
        in_channels=n_channels,
        out_channels=n_channels,
        pretrained="download",  # automatically downloads the pretrained weights, set to a path to use custom weights.
        device=device,
    )
    prior = PnP(denoiser=denoiser)

    # Logging parameters
    verbose = True
    plot_convergence_metrics = True  # compute performance and convergence metrics along the algorithm, curves saved in RESULTS_DIR

    # Instantiate the algorithm class to solve the inverse problem.
    # Initialize with the rescaled adjoint so that the initialization is already at the correct scale.
    model = optim_builder(
        iteration="PGD",
        prior=prior,
        data_fidelity=data_fidelity,
        early_stop=early_stop,
        max_iter=max_iter,
        verbose=verbose,
        params_algo=params_algo,
        custom_init=lambda y, physics: {
            "est": (physics.A_adjoint(y) * SCALING, physics.A_adjoint(y) * SCALING)
        },
    )

    # Set the model to evaluation mode. We do not require training here.
    model.eval()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    BaseOptim(
      (fixed_point): FixedPoint(
        (iterator): PGDIteration(
          (f_step): fStepPGD()
          (g_step): gStepPGD()
        )
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 120-125

Evaluate the model on the problem and plot the results.
--------------------------------------------------------------------

The model returns the output and the metrics computed along the iterations.
To compute the PSNR, the ground truth image ``x_gt`` must be provided.

.. GENERATED FROM PYTHON SOURCE LINES 125-152

.. code-block:: Python

    y = physics(x)
    x_lin = (
        physics.A_adjoint(y) * SCALING
    )  # rescaled linear reconstruction with the adjoint operator

    # run the model on the problem.
    x_model, metrics = model(
        y, physics, x_gt=x, compute_metrics=True
    )  # reconstruction with PnP algorithm

    # compute PSNR
    print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
    print(f"PnP reconstruction PSNR: {dinv.metric.PSNR()(x, x_model).item():.2f} dB")

    # plot images. Images are saved in RESULTS_DIR.
    imgs = [y, x, x_lin, x_model]
    plot(
        imgs,
        titles=["Input", "GT", "Linear", "Recons."],
        save_dir=RESULTS_DIR / "images",
        show=True,
    )

    # plot convergence curves. Metrics are saved in RESULTS_DIR.
    if plot_convergence_metrics:
        plot_curves(metrics, save_dir=RESULTS_DIR / "curves", show=True)

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/plug-and-play/images/sphx_glr_demo_vanilla_PnP_001.png
         :alt: Input, GT, Linear, Recons.
         :srcset: /auto_examples/plug-and-play/images/sphx_glr_demo_vanilla_PnP_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/plug-and-play/images/sphx_glr_demo_vanilla_PnP_002.png
         :alt: PSNR, residual
         :srcset: /auto_examples/plug-and-play/images/sphx_glr_demo_vanilla_PnP_002.png
         :class: sphx-glr-multi-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Linear reconstruction PSNR: -16.02 dB
    PnP reconstruction PSNR: 30.11 dB

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.005 seconds)

.. _sphx_glr_download_auto_examples_plug-and-play_demo_vanilla_PnP.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_vanilla_PnP.ipynb <demo_vanilla_PnP.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_vanilla_PnP.py <demo_vanilla_PnP.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_vanilla_PnP.zip <demo_vanilla_PnP.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_