.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unfolded/demo_LISTA.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unfolded_demo_LISTA.py:


Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing
====================================================================================================

This example shows how to implement the LISTA algorithm :footcite:t:`gregor2010learning`
for a compressed sensing problem.
In a nutshell, LISTA is an unfolded proximal gradient algorithm involving a soft-thresholding
proximal operator with learnable thresholding parameters.

.. GENERATED FROM PYTHON SOURCE LINES 10-22

.. code-block:: Python


    from pathlib import Path

    import torch
    from torchvision import datasets
    from torchvision import transforms
    import deepinv as dinv
    from torch.utils.data import DataLoader
    from deepinv.optim.data_fidelity import L2
    from deepinv.unfolded import unfolded_builder
    from deepinv.utils.demo import get_data_home

.. GENERATED FROM PYTHON SOURCE LINES 23-26

Set up paths for data loading and results.
------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 26-38

.. code-block:: Python


    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"
    ORIGINAL_DATA_DIR = get_data_home()

    # Set the global PyTorch random seed to make the example reproducible.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

.. GENERATED FROM PYTHON SOURCE LINES 39-42

Load base image datasets and degradation operators.
----------------------------------------------------------------------------------------

In this example, we use MNIST as the base dataset.

.. GENERATED FROM PYTHON SOURCE LINES 42-58

.. code-block:: Python


    img_size = 28
    n_channels = 1
    operation = "compressed-sensing"
    train_dataset_name = "MNIST_train"

    # Load the base MNIST training and test sets; ToTensor() maps images to [0, 1].
    train_test_transform = transforms.Compose([transforms.ToTensor()])

    train_base_dataset = datasets.MNIST(
        root=ORIGINAL_DATA_DIR, train=True, transform=train_test_transform, download=True
    )
    test_base_dataset = datasets.MNIST(
        root=ORIGINAL_DATA_DIR, train=False, transform=train_test_transform, download=True
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    Failed to download (trying next): HTTP Error 404: Not Found
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to datasets/MNIST/raw/train-images-idx3-ubyte.gz

.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: demo_LISTA.py `

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: demo_LISTA.zip `

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_
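
As a library-independent complement to the DeepInverse pipeline in this example, the unfolded
scheme described in the introduction can be sketched in plain PyTorch. This is a minimal sketch
under stated assumptions, not the ``deepinv`` implementation (which this demo drives through
``unfolded_builder``): it assumes a dense measurement matrix ``A`` of shape ``(m, n)``, a fixed
step size :math:`\gamma = 1/\|A\|_2^2`, and one learnable threshold :math:`\theta_k` per unrolled
iteration, so that each layer computes
:math:`x_{k+1} = \operatorname{soft}_{\theta_k}\big(x_k - \gamma A^\top (A x_k - y)\big)`.
The names ``soft_threshold`` and ``ToyLISTA`` are hypothetical. The original LISTA formulation
also learns the linear maps of each layer; here only the thresholds are learned.

.. code-block:: Python

    import torch
    import torch.nn as nn


    def soft_threshold(x: torch.Tensor, thresh: torch.Tensor) -> torch.Tensor:
        """Elementwise soft-thresholding, the proximal operator of thresh * ||x||_1."""
        return torch.sign(x) * torch.clamp(x.abs() - thresh, min=0.0)


    class ToyLISTA(nn.Module):
        """Hypothetical LISTA-style network: unrolled ISTA with learnable thresholds."""

        def __init__(self, A: torch.Tensor, n_layers: int = 10, init_thresh: float = 0.01):
            super().__init__()
            self.register_buffer("A", A)  # measurement matrix, shape (m, n)
            # Fixed step size 1 / ||A||_2^2, the usual choice that makes plain ISTA converge.
            self.gamma = 1.0 / torch.linalg.matrix_norm(A, ord=2).item() ** 2
            # One learnable threshold per unrolled iteration (assumption of this sketch).
            self.thresh = nn.Parameter(torch.full((n_layers,), init_thresh))

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            # y has shape (batch, m); the estimate starts at zero.
            x = y.new_zeros((y.shape[0], self.A.shape[1]))
            for theta in self.thresh:
                # Row-vector form of the gradient A^T (A x - y) of 0.5 * ||A x - y||^2.
                grad = (x @ self.A.T - y) @ self.A
                # abs() keeps each threshold non-negative whatever value training gives it.
                x = soft_threshold(x - self.gamma * grad, theta.abs())
            return x


    # Usage on synthetic data: recover 5-sparse vectors from 50 random measurements.
    torch.manual_seed(0)
    m, n = 50, 100
    A = torch.randn(m, n) / m**0.5
    x_true = torch.zeros(4, n)
    x_true[:, :5] = 1.0
    y = x_true @ A.T

    model = ToyLISTA(A, n_layers=200, init_thresh=0.01)
    x_hat = model(y)
    rel_err = (torch.linalg.norm(x_hat - x_true) / torch.linalg.norm(x_true)).item()

    # A supervised loss back-propagates to the per-layer thresholds.
    loss = torch.mean((x_hat - x_true) ** 2)
    loss.backward()
    print(f"relative error: {rel_err:.3f}")
    print(f"threshold gradient norm: {model.thresh.grad.norm().item():.2e}")

Keeping the step size fixed means each unrolled layer adds a single scalar parameter; learning a
per-layer step size as well would be a natural extension of this sketch.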