.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unfolded/demo_vanilla_unfolded.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unfolded_demo_vanilla_unfolded.py:


Vanilla Unfolded algorithm for super-resolution
====================================================================================================

This simple example shows how to use a vanilla unfolded Plug-and-Play algorithm.
The DnCNN denoiser and the algorithm parameters (step size and regularization parameters)
are trained jointly. For simplicity, we train the algorithm on a small dataset;
for optimal results, use a larger one.

.. GENERATED FROM PYTHON SOURCE LINES 11-22

.. code-block:: Python


    import deepinv as dinv
    import torch
    from deepinv.models.utils import get_weights_url
    from torch.utils.data import DataLoader
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.prior import PnP
    from deepinv.optim import DRS
    from torchvision import transforms
    from deepinv.utils import get_data_home
    from deepinv.datasets import BSDS500

.. GENERATED FROM PYTHON SOURCE LINES 23-26

Set up paths for data loading and results.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 26-37

.. code-block:: Python


    BASE_DIR = get_data_home()
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set the global random seed from PyTorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_device()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 8069.25 MiB free memory

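
The introduction describes a vanilla unfolded scheme: a fixed number of iterations of a Plug-and-Play algorithm (here Douglas-Rachford splitting, ``DRS``, which is imported above) is treated as a network, and the denoiser weights are trained together with the algorithm parameters. The block below is a minimal plain-PyTorch sketch of that idea under stated assumptions, not DeepInverse's implementation. To keep the data-fidelity proximal step in closed form, it assumes a simplified forward operator (2x2 average-pooling downsampling, without the Gaussian blur of the ``Downsampling`` operator used in this example), and it replaces DnCNN with a hypothetical two-layer residual CNN. All names (``A``, ``A_adjoint``, ``prox_l2``, ``TinyDenoiser``, ``UnfoldedDRS``, ``log_tau``) are illustrative.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def A(x, factor=2):
        # Assumption: simplified forward operator, average-pool downsampling without blur.
        return F.avg_pool2d(x, factor)


    def A_adjoint(y, factor=2):
        # Adjoint of average pooling: copy each pixel into a factor x factor block, scaled by 1 / factor**2.
        return F.interpolate(y, scale_factor=factor, mode="nearest") / factor**2


    def prox_l2(z, y, tau, factor=2):
        # Closed-form prox of tau * 0.5 * ||A x - y||^2, i.e. solves (I + tau A^T A) x = z + tau A^T y.
        # Since A A^T = c I with c = 1 / factor**2, the Woodbury identity gives
        # (I + tau A^T A)^{-1} = I - tau / (1 + tau * c) * A^T A.
        c = 1.0 / factor**2
        v = z + tau * A_adjoint(y, factor)
        return v - tau / (1 + tau * c) * A_adjoint(A(v, factor), factor)


    class TinyDenoiser(nn.Module):
        # Hypothetical stand-in for DnCNN: a two-layer residual CNN.
        def __init__(self, channels=3, features=16):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(channels, features, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(features, channels, 3, padding=1),
            )

        def forward(self, x):
            # Predict the noise and subtract it (residual learning).
            return x - self.net(x)


    class UnfoldedDRS(nn.Module):
        # A fixed number of Douglas-Rachford iterations, unrolled into a trainable network.
        def __init__(self, n_iter=5, channels=3, factor=2):
            super().__init__()
            self.n_iter = n_iter
            self.factor = factor
            self.denoiser = TinyDenoiser(channels)  # weights shared across iterations
            # One scalar per iteration (step size times data-fidelity weight), stored as a log to stay positive.
            self.log_tau = nn.Parameter(torch.zeros(n_iter))

        def forward(self, y):
            # Initialize with nearest-neighbour upsampling of the measurements.
            z = F.interpolate(y, scale_factor=self.factor, mode="nearest")
            for k in range(self.n_iter):
                x = prox_l2(z, y, self.log_tau[k].exp(), self.factor)  # data-fidelity step
                u = self.denoiser(2 * x - z)  # the denoiser plays the role of the prior's prox
                z = z + u - x  # DRS update of the auxiliary variable
            return x


    if __name__ == "__main__":
        torch.manual_seed(0)
        x_true = torch.rand(4, 3, 32, 32)
        y = A(x_true) + 0.03 * torch.randn(4, 3, 16, 16)
        model = UnfoldedDRS(n_iter=5)
        # A single optimizer over all parameters trains the denoiser and the per-iteration scalars jointly.
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        loss = F.mse_loss(model(y), x_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss after one training step: {loss.item():.4f}")

In this simplified sketch the step size and the regularization weight only ever appear as a product, so they are merged into one learnable scalar per iteration; with a denoiser that takes a noise level as input, they would typically be kept separate.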

.. GENERATED FROM PYTHON SOURCE LINES 38-41

Load base image datasets and degradation operators.
----------------------------------------------------------------------------------------

In this example, we use the training split of the BSDS500 dataset for training and its test split for testing.

.. GENERATED FROM PYTHON SOURCE LINES 41-46

.. code-block:: Python


    img_size = 64 if torch.cuda.is_available() else 32
    n_channels = 3  # 3 for color images, 1 for gray-scale images
    operation = "super-resolution"

.. GENERATED FROM PYTHON SOURCE LINES 47-50

Generate a dataset of low-resolution images and load it.
----------------------------------------------------------------------------------------

We use the ``Downsampling`` class from the physics module to generate a dataset of low-resolution images.

.. GENERATED FROM PYTHON SOURCE LINES 50-104

.. code-block:: Python


    # For simplicity, we use a small dataset for training.
    # Replace it for optimal results, e.g. with the larger DIV2K or LSDIR datasets (also provided in the library).

    # Specify the train and test transforms to be applied to the input images.
    test_transform = transforms.Compose(
        [transforms.CenterCrop(img_size), transforms.ToTensor()]
    )
    train_transform = transforms.Compose(
        [transforms.RandomCrop(img_size), transforms.ToTensor()]
    )
    # Define the base train and test datasets of clean images.
    train_base_dataset = BSDS500(
        BASE_DIR, download=True, train=True, transform=train_transform
    )
    test_base_dataset = BSDS500(
        BASE_DIR, download=False, train=False, transform=test_transform
    )

    # Use a parallel dataloader when a GPU is available to speed up training; otherwise,
    # since all computations run on the CPU, use synchronous data loading.
    num_workers = 4 if torch.cuda.is_available() else 0

    # Degradation parameters
    factor = 2
    noise_level_img = 0.03

    # Generate the Gaussian blur downsampling operator.
    physics = dinv.physics.Downsampling(
        filter="gaussian",
        img_size=(n_channels, img_size, img_size),
        factor=factor,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
    )

    my_dataset_name = "demo_unfolded_sr"
    n_images_max = (
        None if torch.cuda.is_available() else 10
    )  # max number of images used for training (use all if you have a GPU)
    measurement_dir = DATA_DIR / "BSDS500" / operation
    generated_datasets_path = dinv.datasets.generate_dataset(
        train_dataset=train_base_dataset,
        test_dataset=test_base_dataset,
        physics=physics,
        device=device,
        save_dir=measurement_dir,
        train_datapoints=n_images_max,
        num_workers=num_workers,
        dataset_filename=str(my_dataset_name),
    )

    train_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=True)
    test_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=False)


.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: demo_vanilla_unfolded.py `

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: demo_vanilla_unfolded.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_