.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/optimization/demo_custom_prior.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Learn the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_optimization_demo_custom_prior.py:


Image deblurring with a custom explicit prior
====================================================================================================

In this example, we show how to solve a deblurring inverse problem using an explicit prior.

Here we use a simple L2 prior, which penalizes the squared norm of the reconstruction, and solve
the problem with the ADMM algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 10-23

.. code-block:: Python


    import deepinv as dinv
    from pathlib import Path
    import torch
    from torch.utils.data import DataLoader
    from deepinv.optim.prior import Prior
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.optimizers import optim_builder
    from deepinv.training import test
    from torchvision import transforms
    from deepinv.utils.demo import load_dataset

.. GENERATED FROM PYTHON SOURCE LINES 24-27

Set up paths for data loading and results.
--------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-40

.. code-block:: Python


    # Set up paths for measurements, results and degradations.
    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    DEG_DIR = BASE_DIR / "degradations"

    # Set PyTorch's global random seed to make the example reproducible.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

.. GENERATED FROM PYTHON SOURCE LINES 41-45

Load base image datasets and degradation operators.
--------------------------------------------------------------------
In this example, we use the small ``set3c`` dataset of color images, center-cropped to
:math:`256 \times 256` pixels. The degradation (a Gaussian blur followed by additive Gaussian
noise) is defined in the next section.

.. GENERATED FROM PYTHON SOURCE LINES 45-57

.. code-block:: Python


    # Set up the variables used to fetch the dataset and define the degradation.
    method = "L2_prior"
    dataset_name = "set3c"
    operation = "deblur"
    img_size = 256
    val_transform = transforms.Compose(
        [transforms.CenterCrop(img_size), transforms.ToTensor()]
    )

    dataset = load_dataset(dataset_name, transform=val_transform)

.. GENERATED FROM PYTHON SOURCE LINES 58-65

Define the physics operator
--------------------------------------------------------
We use the :class:`deepinv.physics.BlurFFT` operator from the physics module to generate a dataset
of blurred images. This class computes the convolution via the fast Fourier transform, which
implies circular (periodic) boundary conditions. In this example, we choose a Gaussian kernel with
standard deviation 3 pixels along each axis, and we add Gaussian noise with standard deviation 0.03.

.. GENERATED FROM PYTHON SOURCE LINES 65-81

.. code-block:: Python


    # Generate a Gaussian blur filter.
    filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
    noise_level_img = 0.03  # Gaussian noise standard deviation for the degradation
    n_channels = 3  # 3 for color images, 1 for gray-scale images

    # BlurFFT computes the forward operator and its adjoint efficiently with the Fourier transform.
    p = dinv.physics.BlurFFT(
        img_size=(n_channels, img_size, img_size),
        filter=filter_torch,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
    )

.. GENERATED FROM PYTHON SOURCE LINES 82-84

Generate a dataset of blurred images
--------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 84-101

.. code-block:: Python


    # Use a parallel dataloader when a GPU is available to speed up data loading; otherwise,
    # since all computations run on the CPU, load the data synchronously.
    num_workers = 4 if torch.cuda.is_available() else 0
    n_images_max = 3  # Maximal number of images to restore from the input dataset
    measurement_dir = DATA_DIR / dataset_name / operation

    deepinv_dataset_path = dinv.datasets.generate_dataset(
        train_dataset=dataset,
        test_dataset=None,
        physics=p,
        device=device,
        save_dir=measurement_dir,
        train_datapoints=n_images_max,
        num_workers=num_workers,
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Dataset has been saved at measurements/set3c/deblur/dinv_dataset0.h5

.. GENERATED FROM PYTHON SOURCE LINES 102-115

Set up the optimization algorithm to solve the inverse problem.
--------------------------------------------------------------------------------------------
We use the :func:`deepinv.optim.optim_builder` function to instantiate the optimization algorithm.
Here, the algorithm is ADMM (alternating direction method of multipliers), which solves the
following optimization problem:

.. math::

    \min_{x} \frac{1}{2} \|y - Ax\|_2^2 + \frac{\lambda}{2} \|x\|_2^2

where :math:`A` is the forward blurring operator, :math:`y` is the measurement and
:math:`\lambda` is a regularization parameter. The factor :math:`\frac{1}{2}` in the
regularization term comes from the custom prior below, defined as
:math:`g(x) = \frac{1}{2}\|x\|_2^2`, while :math:`\lambda` is the ``"lambda"`` entry of
``params_algo``.

.. GENERATED FROM PYTHON SOURCE LINES 115-163

.. code-block:: Python


    # Create a custom prior which inherits from the base Prior class.
    class L2Prior(Prior):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.explicit_prior = True

        def fn(self, x, *args, **kwargs):
            # g(x) = 0.5 * ||x||_2^2, computed separately for each image in the batch.
            return 0.5 * torch.norm(x.reshape(x.shape[0], -1), p=2, dim=-1) ** 2


    # Specify the custom prior
    prior = L2Prior()

    # Select the data fidelity term
    data_fidelity = L2()

    # Parameters for restoration with the given prior (note that they have not been optimized here)
    params_algo = {"stepsize": 1, "lambda": 0.1}

    # Logging parameters
    verbose = True

    # Parameters of the algorithm to solve the inverse problem
    early_stop = True  # Stop the algorithm when the convergence criterion is reached
    crit_conv = "cost"  # Convergence is reached when the difference of the cost function between
    # consecutive iterates is smaller than thres_conv
    thres_conv = 1e-5
    backtracking = False  # Use backtracking to automatically adjust the stepsize
    max_iter = 500  # Maximum number of iterations

    # Instantiate the algorithm to solve the inverse problem.
    model = optim_builder(
        iteration="ADMM",
        prior=prior,
        g_first=False,
        data_fidelity=data_fidelity,
        params_algo=params_algo,
        early_stop=early_stop,
        max_iter=max_iter,
        crit_conv=crit_conv,
        thres_conv=thres_conv,
        backtracking=backtracking,
        verbose=verbose,
    )

.. GENERATED FROM PYTHON SOURCE LINES 164-168

Evaluate the reconstruction algorithm on the problem.
---------------------------------------------------------------
We can use the :func:`deepinv.test` function to evaluate the reconstruction algorithm on the
generated dataset. Since the blurred images were saved as the training split of the HDF5 file
(``train_dataset=dataset`` above), we load them with ``train=True``.

.. GENERATED FROM PYTHON SOURCE LINES 168-191

.. code-block:: Python


    batch_size = 1
    plot_images = True  # plot results
    plot_convergence_metrics = True  # compute performance and convergence metrics along the iterations; curves are saved in RESULTS_DIR

    dataset = dinv.datasets.HDF5Dataset(path=deepinv_dataset_path, train=True)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )

    test(
        model=model,
        test_dataloader=dataloader,
        physics=p,
        device=device,
        plot_images=plot_images,
        save_folder=RESULTS_DIR / method / operation / dataset_name,
        plot_convergence_metrics=plot_convergence_metrics,
        verbose=verbose,
    )

.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_custom_prior_001.png
         :alt: Ground truth, Measurement, No learning, Reconstruction
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_custom_prior_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/optimization/images/sphx_glr_demo_custom_prior_002.png
         :alt: PSNR, F, residual
         :srcset: /auto_examples/optimization/images/sphx_glr_demo_custom_prior_002.png
         :class: sphx-glr-multi-img
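
As a rough, self-contained illustration of what the ADMM iterations do for this particular prior, the NumPy sketch below solves a 1-D analogue of the problem without calling DeepInverse. It assumes circular boundary conditions, so that the blur is diagonal in the Fourier domain (the principle behind :class:`deepinv.physics.BlurFFT`). The helpers ``gaussian_kernel_1d``, ``circular_blur``, ``closed_form_l2_deblur`` and ``admm_l2_deblur`` are hypothetical names introduced only for this sketch, and taking the data-fidelity step before the prior step is our reading of ``g_first=False``, not a statement about the library's internals. Because :math:`g(x) = \frac{1}{2}\|x\|_2^2`, the prior step reduces to the shrinkage :math:`\operatorname{prox}_{\gamma\lambda g}(v) = v / (1 + \gamma\lambda)`, and the whole problem has the closed-form minimizer :math:`\hat{x} = \mathcal{F}^{-1}\big(\overline{H}\,Y / (|H|^2 + \lambda)\big)`, where :math:`H` and :math:`Y` are the discrete Fourier transforms of the kernel and of the measurement. The sketch uses this closed form as a reference.

.. code-block:: Python

    import numpy as np


    def gaussian_kernel_1d(n, sigma):
        """Normalized 1-D Gaussian kernel of length n, centered at index 0 (circular convention)."""
        t = np.arange(n) - n // 2
        k = np.exp(-0.5 * (t / sigma) ** 2)
        k /= k.sum()
        return np.fft.ifftshift(k)  # move the peak from index n // 2 to index 0


    def circular_blur(x, h):
        """Circular convolution h * x, computed as a pointwise product in the Fourier domain."""
        return np.real(np.fft.ifft(np.fft.fft(h) * np.fft.fft(x)))


    def closed_form_l2_deblur(y, h, lam):
        """Exact minimizer of 0.5 * ||h * x - y||^2 + lam * 0.5 * ||x||^2."""
        H = np.fft.fft(h)
        return np.real(np.fft.ifft(np.conj(H) * np.fft.fft(y) / (np.abs(H) ** 2 + lam)))


    def admm_l2_deblur(y, h, lam=0.1, stepsize=1.0, max_iter=500, thres_conv=1e-10):
        """Hypothetical ADMM sketch for the same objective, stopped on the relative cost change."""
        H = np.fft.fft(h)
        Y = np.fft.fft(y)
        gamma = stepsize
        z = y.copy()  # splitting variable (prior side)
        u = np.zeros_like(y)  # scaled dual variable
        x = z
        cost_prev = np.inf
        for _ in range(max_iter):
            # Data-fidelity step: prox of gamma * 0.5 * ||h * x - y||^2, solved exactly per frequency.
            V = np.fft.fft(z - u)
            x = np.real(np.fft.ifft((np.conj(H) * Y + V / gamma) / (np.abs(H) ** 2 + 1.0 / gamma)))
            # Prior step: prox of gamma * lam * 0.5 * ||.||^2 is a simple shrinkage.
            z = (x + u) / (1.0 + gamma * lam)
            # Dual update.
            u = u + x - z
            # Relative cost-change test, in the spirit of crit_conv="cost" (an assumed analogue).
            cost = 0.5 * np.sum((circular_blur(x, h) - y) ** 2) + 0.5 * lam * np.sum(x**2)
            if abs(cost_prev - cost) < thres_conv * max(cost, 1e-12):
                break
            cost_prev = cost
        return x


    # Usage on a noisy, blurred 1-D "box" signal.
    rng = np.random.default_rng(0)
    n = 64
    x_true = np.zeros(n)
    x_true[20:40] = 1.0
    h = gaussian_kernel_1d(n, sigma=3.0)
    y = circular_blur(x_true, h) + 0.03 * rng.standard_normal(n)

    x_admm = admm_l2_deblur(y, h, lam=0.1)
    x_ref = closed_form_l2_deblur(y, h, lam=0.1)
    # The gap should be small, since the objective has a unique minimizer that ADMM converges to.
    print(np.max(np.abs(x_admm - x_ref)))

The step size and regularization weight match ``params_algo`` (``stepsize=1``, ``lambda=0.1``). The stopping threshold is much tighter than ``thres_conv = 1e-5`` only so that the comparison with the closed-form solution is meaningful; for a general explicit prior, no such closed form is available.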