.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unfolded/demo_LISTA.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `..

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unfolded_demo_LISTA.py:


Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing
====================================================================================================

This example shows how to implement the LISTA algorithm of :footcite:t:`gregor2010learning`
for a compressed sensing problem. In a nutshell, LISTA is an unfolded proximal gradient algorithm
involving a soft-thresholding proximal operator with learnable thresholding parameters.

.. GENERATED FROM PYTHON SOURCE LINES 10-23

.. code-block:: Python

    from pathlib import Path
    import torch
    from torchvision import datasets
    from torchvision import transforms

    import deepinv as dinv
    from torch.utils.data import DataLoader
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim import PGD
    from deepinv.utils import get_cache_home
    from deepinv.models.utils import get_weights_url


.. GENERATED FROM PYTHON SOURCE LINES 24-27

Setup paths for data loading and results.
-----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-39

.. code-block:: Python

    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"
    ORIGINAL_DATA_DIR = get_cache_home() / "datasets" / "MNIST"

    # Set the global PyTorch random seed to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_device()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 5784.25 MiB free memory


.. GENERATED FROM PYTHON SOURCE LINES 40-43

Load base image datasets and degradation operators.
----------------------------------------------------------------------------------------
In this example, we use MNIST as the base dataset.

.. GENERATED FROM PYTHON SOURCE LINES 43-59

.. code-block:: Python

    img_size = 28
    n_channels = 1
    operation = "compressed-sensing"
    train_dataset_name = "MNIST_train"

    # Generate training and evaluation datasets in HDF5 folders and load them.
    train_test_transform = transforms.Compose([transforms.ToTensor()])
    train_base_dataset = datasets.MNIST(
        root=ORIGINAL_DATA_DIR, train=True, transform=train_test_transform, download=True
    )
    test_base_dataset = datasets.MNIST(
        root=ORIGINAL_DATA_DIR, train=False, transform=train_test_transform, download=True
    )


.. GENERATED FROM PYTHON SOURCE LINES 60-67

Generate a dataset of compressed measurements and load it.
----------------------------------------------------------------------------
We use the compressed sensing class from the physics module to generate a dataset of
highly-compressed measurements: :math:`m = 157` measurements for :math:`28 \times 28 = 784` pixels,
i.e. about 20% of the total number of pixels.

The forward operator is defined as :math:`y = Ax` where :math:`A` is a (normalized) random Gaussian matrix.

.. GENERATED FROM PYTHON SOURCE LINES 67-97

.. code-block:: Python

    # Use a parallel dataloader when a GPU is available to speed up training; otherwise,
    # since all computation runs on the CPU, use synchronous data loading.
    num_workers = 4 if torch.cuda.is_available() else 0

    # Generate the compressed sensing measurement operator with 5x under-sampling factor.
    physics = dinv.physics.CompressedSensing(
        m=157, img_size=(n_channels, img_size, img_size), fast=True, device=device
    )
    my_dataset_name = "demo_LISTA"
    n_images_max = (
        5000 if torch.cuda.is_available() else 200
    )  # maximal number of images used for training
    measurement_dir = DATA_DIR / train_dataset_name / operation
    generated_datasets_path = dinv.datasets.generate_dataset(
        train_dataset=train_base_dataset,
        test_dataset=test_base_dataset,
        physics=physics,
        device=device,
        save_dir=measurement_dir,
        train_datapoints=n_images_max,
        test_datapoints=8,
        num_workers=num_workers,
        dataset_filename=str(my_dataset_name),
    )

    train_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=True)
    test_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=False)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Dataset has been saved at measurements/MNIST_train/compressed-sensing/demo_LISTA0.h5


.. GENERATED FROM PYTHON SOURCE LINES 98-123

Define the unfolded Proximal Gradient algorithm.
------------------------------------------------------------------------
In this example, following the original LISTA algorithm :footcite:t:`gregor2010learning`,
the backbone algorithm we unfold is the proximal gradient algorithm, which minimizes the following
objective function

.. math::

    \begin{equation}
    \tag{1}
    \min_x \frac{1}{2} \|y - Ax\|_2^2 + \lambda \|Wx\|_1
    \end{equation}

where :math:`\lambda` is the regularization parameter.
The proximal gradient iteration (see also :class:`deepinv.optim.optim_iterators.PGDIteration`) is defined as

.. math::

    x_{k+1} = \text{prox}_{\gamma \lambda g}(x_k - \gamma A^T (Ax_k - y))

where :math:`\gamma` is the stepsize and :math:`\text{prox}_{g}` is the proximity operator of
:math:`g(x) = \|Wx\|_1`, which corresponds to soft-thresholding in a wavelet basis
(see :class:`deepinv.optim.WaveletPrior`).

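As a minimal sketch of the iteration above, the snippet below runs proximal gradient descent in plain Python. It makes the simplifying assumption that :math:`W` is the identity, so the prox reduces to entrywise soft-thresholding with threshold :math:`\gamma \lambda`, rather than the wavelet transform used in this example. The helper names ``soft_threshold``, ``matvec``, ``rmatvec`` and ``pgd`` are illustrative only and are not part of the deepinv API.

```python
def soft_threshold(v, tau):
    """Entrywise prox of tau * ||.||_1, written as v - clip(v, -tau, tau)."""
    return [vi - max(-tau, min(tau, vi)) for vi in v]


def matvec(A, x):
    """Compute A x for a matrix stored as a list of rows."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def rmatvec(A, r):
    """Compute A^T r without forming the transpose explicitly."""
    return [sum(A[i][j] * r[i] for i in range(len(A))) for j in range(len(A[0]))]


def pgd(A, y, lam, gamma, n_iter):
    """Iterate x <- prox_{gamma*lam*||.||_1}(x - gamma * A^T (A x - y))."""
    x = [0.0] * len(A[0])  # start from the zero vector
    for _ in range(n_iter):
        residual = [ax - yi for ax, yi in zip(matvec(A, x), y)]
        grad = rmatvec(A, residual)  # gradient of 0.5 * ||y - A x||_2^2
        x = soft_threshold([xi - gamma * gi for xi, gi in zip(x, grad)], gamma * lam)
    return x


# Toy problem (assumption: A is the 3x3 identity, so the minimizer of (1)
# with W = I is simply soft_threshold(y, lam)).
A = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
y = [3.0, -0.5, 1.5]
x_hat = pgd(A, y, lam=1.0, gamma=1.0, n_iter=5)
print(x_hat)
```

In this sketch ``gamma`` and ``lam`` are fixed scalars shared by all iterations; LISTA instead treats them as learnable parameters, with one value per unrolled iteration.
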
We use :func:`deepinv.optim.PGD` with ``unfold=True`` to define the unfolded algorithm and set both the
stepsizes of the LISTA algorithm :math:`\gamma` (``stepsize``) and the soft-thresholding parameters
:math:`\lambda` (``sigma_denoiser``) as learnable parameters.
These parameters are initialized with a list of length ``max_iter``, yielding a distinct value of each
parameter for each iteration of the algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 123-140

.. code-block:: Python

    # Select the data fidelity term
    data_fidelity = L2()

    max_iter = 10  # Number of unrolled iterations
    stepsize = [torch.ones(1, device=device)] * max_iter  # initialization of the stepsizes.
    # A distinct stepsize is trained for each iteration.

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # To share a single trained prior across all iterations, initialize with a single model.
    level = 3
    prior = [
        dinv.optim.WaveletPrior(wv="db8", level=level, device=device)
        for i in range(max_iter)
    ]


.. GENERATED FROM PYTHON SOURCE LINES 141-153

In practice, it is common to apply a different thresholding parameter to each wavelet sub-band.
The thresholding parameter is then a tensor of shape (n_levels, n_wavelet_subbands) and
problem (1) is reformulated as

.. math::

    \begin{equation}
    \min_x \frac{1}{2} \|y - Ax\|_2^2 + \sum_{i, j} \lambda_{i, j} \|\left(Wx\right)_{i, j}\|_1
    \end{equation}

where :math:`\lambda_{i, j}` is the thresholding parameter for the wavelet sub-band :math:`j` at level :math:`i`.
Note that in this case, the prior is a list of elements containing the terms
:math:`\|\left(Wx\right)_{i, j}\|_1=g_{i, j}(x)`, and the dimension of the thresholding parameter
must match that of :math:`g_{i, j}`.

.. GENERATED FROM PYTHON SOURCE LINES 154-178

.. code-block:: Python

    # Unrolled optimization algorithm parameters.
    sigma_denoiser = [
        torch.ones(1, 3, 3, device=device)
        * 0.01  # initialization of the regularization parameter. One thresholding parameter per wavelet sub-band and level.
    ] * max_iter  # A distinct regularization parameter is trained for each iteration.

    trainable_params = [
        "stepsize",
        "sigma_denoiser",
    ]  # define which parameters are trainable

    # Define the unfolded trainable model.
    model = PGD(
        unfold=True,
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        stepsize=stepsize,
        sigma_denoiser=sigma_denoiser,
    ).to(device)


.. GENERATED FROM PYTHON SOURCE LINES 179-185

Define the training parameters.
-------------------------------

We now define the training-related parameters: the number of epochs, the optimizer (Adam) and its
hyperparameters, and the train and test batch sizes.

.. GENERATED FROM PYTHON SOURCE LINES 185-220

.. code-block:: Python

    # Training parameters
    epochs = 5 if torch.cuda.is_available() else 1
    learning_rate = 1e-2

    # Choose optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Choose supervised training loss
    losses = [dinv.loss.SupLoss(metric=dinv.metric.MSE())]

    # Logging parameters
    verbose = True

    # Batch sizes and data loaders
    train_batch_size = 128 if torch.cuda.is_available() else 2
    test_batch_size = 128 if torch.cuda.is_available() else 8

    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=test_batch_size, num_workers=num_workers, shuffle=False
    )

    # If working on CPU, start with a pretrained model to reduce training time
    if not torch.cuda.is_available():
        file_name = "ckp_10_demo_LISTA.pth.tar"
        url = get_weights_url(model_name="demo", file_name=file_name)
        ckpt = torch.hub.load_state_dict_from_url(
            url, map_location=lambda storage, loc: storage, file_name=file_name
        )
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])


.. GENERATED FROM PYTHON SOURCE LINES 221-226

Train the network.
-------------------------------------------

We train the network using the :class:`deepinv.Trainer` class.

.. GENERATED FROM PYTHON SOURCE LINES 226-244

.. code-block:: Python

    trainer = dinv.Trainer(
        model,
        physics=physics,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=epochs,
        losses=losses,
        optimizer=optimizer,
        device=device,
        save_path=str(CKPT_DIR / operation),
        verbose=verbose,
        show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
    )

    model = trainer.train()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1356: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      self.setup_train()
    The model has 100 trainable parameters
    Train epoch 0: TotalLoss=0.052, PSNR=13.052
    Eval epoch 0: PSNR=13.589
    Best model saved at epoch 1
    Train epoch 1: TotalLoss=0.049, PSNR=13.291
    Eval epoch 1: PSNR=13.682
    Best model saved at epoch 2
    Train epoch 2: TotalLoss=0.048, PSNR=13.421
    Eval epoch 2: PSNR=13.736
    Best model saved at epoch 3
    Train epoch 3: TotalLoss=0.047, PSNR=13.501
    Eval epoch 3: PSNR=13.771
    Best model saved at epoch 4
    Train epoch 4: TotalLoss=0.046, PSNR=13.556
    Eval epoch 4: PSNR=13.798
    Best model saved at epoch 5


.. GENERATED FROM PYTHON SOURCE LINES 245-252

Test the network.
---------------------------

We now test the learned unrolled network on the test dataset. In the plotted results, the first column shows the
measurements back-projected in the image domain, the second column shows the output of our LISTA network,
and the third shows the ground truth.

.. GENERATED FROM PYTHON SOURCE LINES 252-273

.. code-block:: Python

    trainer.test(test_dataloader)

    test_sample, _ = next(iter(test_dataloader))
    model.eval()
    test_sample = test_sample.to(device)

    # Get the measurements and the ground truth
    y = physics(test_sample)
    with torch.no_grad():  # it is important to disable gradient computation during testing.
        rec = model(y, physics=physics)

    backprojected = physics.A_adjoint(y)

    dinv.utils.plot(
        [backprojected, rec, test_sample],
        titles=["Linear", "Reconstruction", "Ground truth"],
        suptitle="Reconstruction results",
    )


.. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_LISTA_001.png
   :alt: Reconstruction results, Linear, Reconstruction, Ground truth
   :srcset: /auto_examples/unfolded/images/sphx_glr_demo_LISTA_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1548: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      self.setup_train(train=False)
    Eval epoch 0: PSNR=13.798, PSNR no learning=11.714
    Test results:
    PSNR no learning: 11.714 +- 1.828
    PSNR: 13.798 +- 1.568
    /local/jtachell/deepinv/deepinv/deepinv/utils/plotting.py:408: UserWarning: This figure was using a layout engine that is incompatible with subplots_adjust and/or tight_layout; not calling subplots_adjust.
      fig.subplots_adjust(top=0.75)


.. GENERATED FROM PYTHON SOURCE LINES 274-277

:References:

.. footbibliography::


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 29.079 seconds)


.. _sphx_glr_download_auto_examples_unfolded_demo_LISTA.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_LISTA.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_LISTA.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_LISTA.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_