Vanilla Unfolded algorithm for super-resolution

This is a simple example showing how to use vanilla unfolded Plug-and-Play (PnP). The DnCNN denoiser and the algorithm parameters (step size, denoiser noise level and relaxation parameter) are trained jointly. For simplicity, we train the algorithm on a small dataset; for optimal results, use a larger one.

import deepinv as dinv
import torch
from deepinv.models.utils import get_weights_url
from torch.utils.data import DataLoader
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP
from deepinv.optim import DRS
from torchvision import transforms
from deepinv.utils import get_data_home
from deepinv.datasets import BSDS500

Set up paths for data loading and results.

BASE_DIR = get_data_home()
DATA_DIR = BASE_DIR / "measurements"
RESULTS_DIR = BASE_DIR / "results"
CKPT_DIR = BASE_DIR / "ckpts"

# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)

device = dinv.utils.get_device()
Selected GPU 0 with 8069.25 MiB free memory

Load base image datasets and degradation operators.

In this example, we use the training split of the BSDS500 dataset for training and its test split for testing.

img_size = 64 if torch.cuda.is_available() else 32
n_channels = 3  # 3 for color images, 1 for gray-scale images
operation = "super-resolution"

Generate a dataset of low-resolution images and load it.

We use the Downsampling class from the physics module, together with dinv.datasets.generate_dataset, to generate a dataset of low-resolution images.
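Mathematically, the measurements are y = A x + n, where A blurs the image with a Gaussian filter and keeps every `factor`-th pixel in each direction, and n is Gaussian noise with standard deviation `noise_level_img`. The NumPy sketch below builds such an operator for a single-channel image, together with its adjoint A^T (zero-filling followed by correlation with the filter, i.e. the kind of backprojection later shown as "Linear" via `physics.A_adjoint`), and checks the adjoint identity <A x, v> = <x, A^T v>. The kernel size and width and the circular boundary handling are assumptions for illustration; they are not necessarily what `dinv.physics.Downsampling` uses internally.

```python
import numpy as np

# Hypothetical stand-in for a Gaussian blur + downsampling operator on a
# single-channel image, assuming circular boundary conditions.


def gaussian_kernel(size=7, std=1.0):
    """Normalized 2D Gaussian kernel (size and std are illustrative choices)."""
    ax = np.arange(size) - size // 2
    g = np.exp(-(ax**2) / (2.0 * std**2))
    k = np.outer(g, g)
    return k / k.sum()


def kernel_fft(kernel, shape):
    """FFT of the kernel zero-padded to `shape` and centered at pixel (0, 0)."""
    pad = np.zeros(shape)
    kh, kw = kernel.shape
    pad[:kh, :kw] = kernel
    pad = np.roll(pad, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(pad)


def A(x, kernel, factor):
    """Forward operator: circular blur, then keep every `factor`-th pixel."""
    blurred = np.real(np.fft.ifft2(np.fft.fft2(x) * kernel_fft(kernel, x.shape)))
    return blurred[::factor, ::factor]


def A_adjoint(y, kernel, factor, shape):
    """Adjoint: zero-filled upsampling, then circular correlation with the kernel."""
    up = np.zeros(shape)
    up[::factor, ::factor] = y
    return np.real(np.fft.ifft2(np.fft.fft2(up) * np.conj(kernel_fft(kernel, shape))))


rng = np.random.default_rng(0)
img_size, factor, noise_level_img = 32, 2, 0.03
x = rng.random((img_size, img_size))  # stand-in "clean" image with values in [0, 1]
k = gaussian_kernel()
y = A(x, k, factor) + noise_level_img * rng.standard_normal((img_size // factor,) * 2)

# Adjoint test: <A x, v> should match <x, A^T v> up to round-off.
v = rng.standard_normal(y.shape)
lhs = np.sum(A(x, k, factor) * v)
rhs = np.sum(x * A_adjoint(v, k, factor, x.shape))
print(y.shape, np.isclose(lhs, rhs))
```

Because A^T only zero-fills the missing pixels, most of the image energy is lost in the backprojection, which is one reason a purely linear reconstruction scores poorly.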

# For simplicity, we use a small dataset for training.
# To be replaced for optimal results. For example, you can use the larger DIV2K or LSDIR datasets (also provided in the library).

# Specify the train and test transforms to be applied to the input images.
test_transform = transforms.Compose(
    [transforms.CenterCrop(img_size), transforms.ToTensor()]
)
train_transform = transforms.Compose(
    [transforms.RandomCrop(img_size), transforms.ToTensor()]
)
# Define the base train and test datasets of clean images.
train_base_dataset = BSDS500(
    BASE_DIR, download=True, train=True, transform=train_transform
)
test_base_dataset = BSDS500(
    BASE_DIR, download=False, train=False, transform=test_transform
)

# Use parallel data loading when a GPU is available to speed up training; otherwise, since
# all computation already runs on the CPU, load data synchronously.
num_workers = 4 if torch.cuda.is_available() else 0

# Degradation parameters
factor = 2
noise_level_img = 0.03

# Generate the Gaussian blur + downsampling operator.
physics = dinv.physics.Downsampling(
    filter="gaussian",
    img_size=(n_channels, img_size, img_size),
    factor=factor,
    device=device,
    noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
)
my_dataset_name = "demo_unfolded_sr"
n_images_max = (
    None if torch.cuda.is_available() else 10
)  # max number of images used for training (use all if you have a GPU)
measurement_dir = DATA_DIR / "BSDS500" / operation
generated_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_base_dataset,
    test_dataset=test_base_dataset,
    physics=physics,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    num_workers=num_workers,
    dataset_filename=str(my_dataset_name),
)

train_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=False)
100%|██████████| 160M/160M [00:04<00:00, 40.7MB/s]
Extracting: 100%|██████████| 2492/2492 [00:01<00:00, 2119.41it/s]
Dataset has been saved at datasets/measurements/BSDS500/super-resolution/demo_unfolded_sr0.h5

Define the unfolded PnP algorithm.

The chosen algorithm here is DRS (Douglas-Rachford Splitting). If the prior (or a parameter) is initialized with a list of length max_iter, a distinct model (or parameter value) is trained for each iteration. To train a single prior (or parameter) shared across all iterations, initialize it with a single element.
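One common way to write relaxed DRS with a denoiser D in place of the prior's proximal operator is u_k = prox_{γ_k f}(z_k), v_k = D_{σ_k}(2 u_k - z_k), z_{k+1} = z_k + β_k (v_k - u_k), with the toy data term f(x) = ½‖Ax - y‖². The NumPy sketch below unrolls max_iter such iterations on a hypothetical 1D problem (A keeps every second sample) and mimics the list-versus-single-element convention with a helper `per_iter`. The fixed smoothing filter only stands in for DnCNN, nothing is trained here, and the exact update order inside `deepinv.optim.DRS` may differ.

```python
import numpy as np

# Hypothetical toy 1D problem: A keeps every second sample, so A^T A is a 0/1 mask.
rng = np.random.default_rng(0)
n = 64
t = np.arange(n)
x_true = np.sin(2 * np.pi * 2 * t / n)
mask = np.zeros(n)
mask[::2] = 1.0
y = x_true[::2] + 0.03 * rng.standard_normal(n // 2)
Aty = np.zeros(n)
Aty[::2] = y  # A^T y: zero-filled measurements


def prox_l2(z, gamma):
    """prox of gamma * 0.5 * ||A x - y||^2: solve (gamma A^T A + I) x = gamma A^T y + z."""
    return (gamma * Aty + z) / (gamma * mask + 1.0)


def denoiser(v, sigma):
    """Stand-in for DnCNN: fixed [1/4, 1/2, 1/4] circular smoothing; sigma is ignored."""
    return 0.25 * np.roll(v, 1) + 0.5 * v + 0.25 * np.roll(v, -1)


def per_iter(param, k):
    """A list holds one value per iteration; a scalar is shared by all iterations."""
    return param[k] if isinstance(param, list) else param


def unfolded_drs(max_iter, stepsize, sigma_denoiser, beta):
    z = Aty.copy()  # initialize from the zero-filled backprojection
    for k in range(max_iter):
        u = prox_l2(z, per_iter(stepsize, k))  # data-fidelity step
        v = denoiser(2 * u - z, per_iter(sigma_denoiser, k))  # denoising (prior) step
        z = z + per_iter(beta, k) * (v - u)  # relaxed DRS update
    return v  # return the last denoiser output


max_iter = 5
rec = unfolded_drs(
    max_iter, stepsize=[1.0] * max_iter, sigma_denoiser=[1.0] * max_iter, beta=1.0
)
mse_rec = np.mean((rec - x_true) ** 2)
mse_zero_fill = np.mean((Aty - x_true) ** 2)
print(mse_rec < mse_zero_fill)
```

In the unfolded setting, each `stepsize[k]` and `sigma_denoiser[k]` becomes a learnable scalar, while the scalar `beta` is a single learnable value shared by all iterations.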

# Unrolled optimization algorithm parameters
max_iter = 5  # number of unfolded layers

# Select the data fidelity term
data_fidelity = L2()

# Set up the trainable denoising prior
# Here the prior model is common for all iterations
prior = PnP(denoiser=dinv.models.DnCNN(depth=20, pretrained="download").to(device))

# The parameters are initialized with a list of length max_iter, so that a distinct parameter is trained for each iteration.
stepsize = [1.0] * max_iter  # stepsize of the algorithm
sigma_denoiser = [
    1.0
] * max_iter  # noise level parameter of the denoiser (not used by DnCNN)
beta = 1.0  # relaxation parameter of the Douglas-Rachford splitting
trainable_params = [
    "stepsize",
    "beta",
    "sigma_denoiser",
]  # define which parameters are trainable

# Logging parameters
verbose = True

# Define the unfolded trainable model.
model = DRS(
    stepsize=stepsize,
    sigma_denoiser=sigma_denoiser,
    beta=beta,
    trainable_params=trainable_params,
    data_fidelity=data_fidelity,
    max_iter=max_iter,
    prior=prior,
    unfold=True,
)

Define the training parameters.

We use the Adam optimizer with a fixed learning rate; no learning-rate scheduler is used in this example.

# training parameters
epochs = 5 if torch.cuda.is_available() else 1
learning_rate = 5e-4
train_batch_size = 32 if torch.cuda.is_available() else 1
test_batch_size = 3

# choose the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)

# If working on CPU, start with a pretrained model to reduce training time
if not torch.cuda.is_available():
    file_name = "demo_vanilla_unfolded.pth"
    url = get_weights_url(model_name="demo", file_name=file_name)
    ckpt = torch.hub.load_state_dict_from_url(
        url, map_location=lambda storage, loc: storage, file_name=file_name
    )
    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optimizer"])

# choose supervised training loss
losses = [dinv.loss.SupLoss(metric=dinv.metric.MSE())]

train_dataloader = DataLoader(
    train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=test_batch_size, num_workers=num_workers, shuffle=False
)
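The optimizer above keeps the learning rate fixed. If a step decay is wanted, PyTorch provides `torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma)`, which, when `scheduler.step()` is called once per epoch, multiplies the learning rate by `gamma` every `step_size` epochs. The plain-Python sketch below evaluates that closed form; `step_size=2` and `gamma=0.5` are illustrative values, not tuned for this example. Whether a scheduler can be handed to `dinv.Trainer` (for instance through a `scheduler` argument) depends on your deepinv version, so check its documentation before relying on it.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Closed form of a StepLR schedule: decay by `gamma` every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


learning_rate = 5e-4
schedule = [step_lr(learning_rate, e, step_size=2, gamma=0.5) for e in range(5)]
print(schedule)
```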

Train the network

We train the network using the deepinv.Trainer class.

trainer = dinv.Trainer(
    model,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=epochs,
    losses=losses,
    optimizer=optimizer,
    device=device,
    early_stop=3,  # stop after 3 evaluations without improvement; set to None to disable
    save_path=str(CKPT_DIR / operation),
    verbose=verbose,
    show_progress_bar=False,  # disable the progress bar for better rendering in Sphinx-Gallery
)

model = trainer.train()
The model has 668238 trainable parameters
Train epoch 0: TotalLoss=0.009, PSNR=21.59
Eval epoch 0: PSNR=20.576
Best model saved at epoch 1
Train epoch 1: TotalLoss=0.007, PSNR=22.994
Eval epoch 1: PSNR=21.455
Best model saved at epoch 2
Train epoch 2: TotalLoss=0.007, PSNR=23.59
Eval epoch 2: PSNR=21.325
Train epoch 3: TotalLoss=0.006, PSNR=23.62
Eval epoch 3: PSNR=21.584
Best model saved at epoch 4
Train epoch 4: TotalLoss=0.006, PSNR=23.961
Eval epoch 4: PSNR=21.891
Best model saved at epoch 5
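The log reports a new best checkpoint whenever the evaluation PSNR improves. The sketch below replays the logged eval PSNRs through a hypothetical patience-based rule (`best_and_stop` is illustrative, not deepinv's code), assuming a patience of 3 evaluations. It recovers the saved checkpoints (the log numbers saved epochs from 1, hence the +1 offset) and shows that early stopping is not triggered within these 5 epochs.

```python
# Eval PSNRs copied from the training log above (epochs 0 to 4).
eval_psnr = [20.576, 21.455, 21.325, 21.584, 21.891]


def best_and_stop(scores, patience):
    """Return the 0-indexed epochs with a new best score, and the epoch at which
    training would stop after `patience` evaluations without improvement
    (None if it never stops early). Hypothetical helper for illustration."""
    best, since_best, saved = float("-inf"), 0, []
    for epoch, score in enumerate(scores):
        if score > best:
            best, since_best = score, 0
            saved.append(epoch)
        else:
            since_best += 1
            if since_best >= patience:
                return saved, epoch
    return saved, None


saved_epochs, stop_epoch = best_and_stop(eval_psnr, patience=3)
print(saved_epochs, stop_epoch)  # [0, 1, 3, 4] None
```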

Test the network

trainer.test(test_dataloader)

test_sample, _ = next(iter(test_dataloader))
model.eval()
test_sample = test_sample.to(device)

# Simulate measurements (with fresh noise) from the ground-truth test images
y = physics(test_sample)
with torch.no_grad():
    rec = model(y, physics=physics)

backprojected = physics.A_adjoint(y)

dinv.utils.plot(
    [backprojected, rec, test_sample],
    titles=["Linear", "Reconstruction", "Ground truth"],
    suptitle="Reconstruction results",
)
[Figure: Reconstruction results (panels: Linear, Reconstruction, Ground truth)]
Eval epoch 0: PSNR=21.891, PSNR no learning=9.122
Test results:
PSNR no learning: 9.122 +- 2.903
PSNR: 21.891 +- 3.692
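The PSNR values above are computed per image and then averaged. Assuming images in [0, 1] (peak value 1), PSNR = 10 log10(1 / MSE) dB, so the final training TotalLoss of 0.006 corresponds to about 22.2 dB, below the logged training PSNR of 23.961. Part of that gap is expected, assuming TotalLoss is an MSE averaged over the data: PSNR is a convex function of the MSE, so averaging per-image PSNRs gives a higher number than converting the average MSE (rounding of the logged loss also contributes). The sketch below illustrates this with two hypothetical per-image MSEs.

```python
import math


def psnr_from_mse(mse, max_pixel=1.0):
    """PSNR in dB for images with values in [0, max_pixel]."""
    return 10.0 * math.log10(max_pixel**2 / mse)


# Two hypothetical per-image MSEs.
mses = [0.01, 0.001]
mean_of_psnr = sum(psnr_from_mse(m) for m in mses) / len(mses)  # 25 dB
psnr_of_mean = psnr_from_mse(sum(mses) / len(mses))  # about 22.6 dB
print(round(mean_of_psnr, 2), round(psnr_of_mean, 2))
```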

Total running time of the script: (0 minutes 38.525 seconds)
