Vanilla Unfolded algorithm for super-resolution
This is a simple example to show how to use vanilla unfolded Plug-and-Play.
The DnCNN denoiser and the algorithm parameters (stepsize, regularization parameters) are trained jointly.
For simplicity, we show how to train the algorithm on a small dataset. For optimal results, use a larger dataset.
For visualizing the training, you can use Weights & Biases (wandb) by setting wandb_vis=True.
import deepinv as dinv
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP
from deepinv.unfolded import unfolded_builder
from torchvision import transforms
from deepinv.utils.demo import load_dataset
Set up paths for data loading and results.
BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
DATA_DIR = BASE_DIR / "measurements"
RESULTS_DIR = BASE_DIR / "results"
CKPT_DIR = BASE_DIR / "ckpts"
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load base image datasets and degradation operators.
In this example, we use the CBSD500 dataset for training and the Set3C dataset for testing.
img_size = 64 if torch.cuda.is_available() else 32
n_channels = 3 # 3 for color images, 1 for gray-scale images
operation = "super-resolution"
Generate a dataset of low-resolution images and load it.
We use the Downsampling class from the physics module to generate a dataset of low-resolution images.
# For simplicity, we use a small dataset for training.
# To be replaced for optimal results. For example, you can use the larger "drunet" dataset.
train_dataset_name = "CBSD500"
test_dataset_name = "set3c"
# Specify the train and test transforms to be applied to the input images.
test_transform = transforms.Compose(
[transforms.CenterCrop(img_size), transforms.ToTensor()]
)
train_transform = transforms.Compose(
[transforms.RandomCrop(img_size), transforms.ToTensor()]
)
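The two transforms differ only in where the patch is taken: RandomCrop samples a new position each time an image is loaded (a cheap form of data augmentation for training), while CenterCrop always returns the same central patch, so test metrics are reproducible. Below is a minimal NumPy sketch of both behaviours on a (C, H, W) array, assuming NumPy is available; `center_crop` and `random_crop` are hypothetical helpers, not torchvision functions.

```python
import numpy as np


def center_crop(img, size):
    """Return the central size x size patch of a (C, H, W) array.

    Mirrors the idea of transforms.CenterCrop; torchvision's rounding of the
    offset may differ by one pixel when (H - size) is odd.
    """
    _, h, w = img.shape
    top = (h - size) // 2
    left = (w - size) // 2
    return img[:, top : top + size, left : left + size]


def random_crop(img, size, rng):
    """Return a size x size patch at a uniformly random position (like RandomCrop)."""
    _, h, w = img.shape
    top = rng.integers(0, h - size + 1)  # upper bound is exclusive
    left = rng.integers(0, w - size + 1)
    return img[:, top : top + size, left : left + size]


# Toy 1-channel 5x5 "image" whose pixel values are 0..24.
img = np.arange(25, dtype=float).reshape(1, 5, 5)
rng = np.random.default_rng(0)

test_patch = center_crop(img, 3)  # always the same patch: rows 1-3, cols 1-3
train_patch = random_crop(img, 3, rng)  # position changes from call to call
print(test_patch[0])
```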
# Define the base train and test datasets of clean images.
train_base_dataset = load_dataset(
train_dataset_name, ORIGINAL_DATA_DIR, transform=train_transform
)
test_base_dataset = load_dataset(
test_dataset_name, ORIGINAL_DATA_DIR, transform=test_transform
)
# Use a parallel dataloader when a GPU is available to speed up training; otherwise, since all computation runs on the CPU,
# use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0
# Degradation parameters
factor = 2
noise_level_img = 0.03
# Generate the Gaussian blur + downsampling operator.
physics = dinv.physics.Downsampling(
filter="gaussian",
img_size=(n_channels, img_size, img_size),
factor=factor,
device=device,
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
)
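The physics object above encodes a measurement model of the form y = S(k * x) + n: the image x is blurred by a Gaussian kernel k, subsampled by `factor` along each spatial axis (S), and corrupted by Gaussian noise n with standard deviation `noise_level_img`. The NumPy sketch below reproduces that structure for intuition only; the blur width, the reflect boundary handling and the helper names (`gaussian_kernel_1d`, `blur_separable`, `downsampling_forward`) are assumptions and do not match `dinv.physics.Downsampling` exactly.

```python
import numpy as np


def gaussian_kernel_1d(std, radius=None):
    """Normalized 1D Gaussian kernel; the std used below is an illustrative choice."""
    if radius is None:
        radius = int(np.ceil(3 * std))
    t = np.arange(-radius, radius + 1, dtype=float)
    k = np.exp(-0.5 * (t / std) ** 2)
    return k / k.sum()


def blur_separable(img, kernel):
    """Blur a (C, H, W) array along H then W, using reflect padding at the borders."""
    pad = len(kernel) // 2

    def conv1d(v):
        return np.convolve(np.pad(v, pad, mode="reflect"), kernel, mode="valid")

    out = np.apply_along_axis(conv1d, 1, img)  # blur along H
    return np.apply_along_axis(conv1d, 2, out)  # blur along W


def downsampling_forward(x, factor, blur_std, noise_sigma, rng):
    """Toy y = decimate(blur(x)) + noise, the structure of a blur + downsampling model."""
    blurred = blur_separable(x, gaussian_kernel_1d(blur_std))
    decimated = blurred[:, ::factor, ::factor]  # keep one pixel out of `factor` per axis
    return decimated + noise_sigma * rng.standard_normal(decimated.shape)


rng = np.random.default_rng(0)
x = rng.random((3, 64, 64))  # stand-in for a clean (C, H, W) image in [0, 1]
y = downsampling_forward(x, factor=2, blur_std=1.0, noise_sigma=0.03, rng=rng)
print(y.shape)  # (3, 32, 32)
```

With factor=2, each measurement has a quarter of the pixels of the clean image, which is why the reconstruction problem is ill-posed and benefits from a learned prior.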
my_dataset_name = "demo_unfolded_sr"
n_images_max = (
1000 if torch.cuda.is_available() else 10
) # maximal number of images used for training
measurement_dir = DATA_DIR / train_dataset_name / operation
generated_datasets_path = dinv.datasets.generate_dataset(
train_dataset=train_base_dataset,
test_dataset=test_base_dataset,
physics=physics,
device=device,
save_dir=measurement_dir,
train_datapoints=n_images_max,
num_workers=num_workers,
dataset_filename=str(my_dataset_name),
)
train_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=False)
Downloading datasets/CBSD500.zip
CBSD500 dataset downloaded in datasets
Downloading datasets/set3c.zip
set3c dataset downloaded in datasets
Dataset has been saved in measurements/CBSD500/super-resolution
Define the unfolded PnP algorithm.
We use the helper function deepinv.unfolded.unfolded_builder()
to define the unfolded architecture.
The chosen algorithm here is DRS (Douglas-Rachford Splitting).
Note that if the prior (resp. a parameter) is initialized with a list of length max_iter,
then a distinct model (resp. parameter) is trained for each iteration.
To train a single prior (resp. parameter) shared across all iterations, initialize it with a single element.
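Conceptually, each of the max_iter layers of the unfolded network performs one Douglas-Rachford step in which the proximal operator of the prior is replaced by a denoiser (the Plug-and-Play principle), and each layer reads its own stepsize and denoiser strength from the per-iteration lists. The following is a minimal NumPy sketch of a generic relaxed DRS recursion, x = prox_{γf}(z), u = D(2x - z), z ← z + β(u - x); it is not deepinv's implementation, whose update order and parameterization may differ. To keep the data-fidelity proximal operator in closed form, the sketch assumes a 1D signal and pure factor-2 decimation without blur, and it swaps DnCNN for a fixed [1, 2, 1]/4 smoothing filter; `per_iteration`, `prox_data_fidelity`, `toy_denoiser` and `unfolded_drs` are hypothetical names.

```python
import numpy as np


def per_iteration(value, n_iter):
    """Broadcast a scalar to a per-iteration list; pass lists through unchanged."""
    return list(value) if isinstance(value, (list, tuple)) else [value] * n_iter


def prox_data_fidelity(z, y_up, mask, gamma):
    """prox of gamma * 0.5 * ||S x - y||^2 when S keeps the samples where mask == 1."""
    return (z + gamma * y_up) / (1.0 + gamma * mask)


def toy_denoiser(v, strength):
    """Stand-in for DnCNN: blend v with a circular [1, 2, 1] / 4 smoothing of v."""
    smoothed = (np.roll(v, 1) + 2 * v + np.roll(v, -1)) / 4
    return (1 - strength) * v + strength * smoothed


def unfolded_drs(y_up, mask, stepsizes, strengths, beta=1.0, max_iter=5):
    stepsizes = per_iteration(stepsizes, max_iter)
    strengths = per_iteration(strengths, max_iter)
    betas = per_iteration(beta, max_iter)
    z = y_up.copy()  # initialize with the zero-filled measurement S^T y
    for k in range(max_iter):  # each pass is one "layer" of the unfolded network
        x = prox_data_fidelity(z, y_up, mask, stepsizes[k])
        u = toy_denoiser(2 * x - z, strengths[k])  # PnP: denoiser replaces prox of the prior
        z = z + betas[k] * (u - x)  # relaxed Douglas-Rachford update
    return x


n = 64
x_true = np.sin(2 * np.pi * np.arange(n) / n)  # smooth periodic 1D "image"
mask = (np.arange(n) % 2 == 0).astype(float)  # factor-2 decimation, no blur
y_up = mask * x_true  # S^T y: measurements on the fine grid, zeros elsewhere

x_hat = unfolded_drs(y_up, mask, stepsizes=[1.0] * 5, strengths=[1.0] * 5, beta=1.0)
print(np.mean((y_up - x_true) ** 2), np.mean((x_hat - x_true) ** 2))
```

Passing a scalar (here beta) mimics a parameter shared across iterations, while lists give each layer its own value. In the unfolded model trained in this example, these values (and the denoiser weights) are learned from data instead of being fixed by hand.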
# Unrolled optimization algorithm parameters
max_iter = 5 # number of unfolded layers
# Select the data fidelity term
data_fidelity = L2()
# Set up the trainable denoising prior
# Here the prior model is common for all iterations
prior = PnP(denoiser=dinv.models.DnCNN(depth=7, pretrained=None).to(device))
# stepsize and g_param are initialized with lists of length max_iter, so a distinct value is trained for each iteration;
# beta is a single scalar shared across iterations.
stepsize = [1] * max_iter # stepsize of the algorithm
sigma_denoiser = [0.01] * max_iter # noise level parameter of the denoiser
beta = 1 # relaxation parameter of the Douglas-Rachford splitting
params_algo = { # wrap all the restoration parameters in a 'params_algo' dictionary
"stepsize": stepsize,
"g_param": sigma_denoiser,
"beta": beta,
}
trainable_params = [
"g_param",
"stepsize",
"beta",
] # define which parameters from 'params_algo' are trainable
# Logging parameters
verbose = True
wandb_vis = False # plot curves and images in Weights & Biases
# Define the unfolded trainable model.
model = unfolded_builder(
iteration="DRS",
params_algo=params_algo.copy(),
trainable_params=trainable_params,
data_fidelity=data_fidelity,
max_iter=max_iter,
prior=prior,
)
Define the training parameters.
We use the Adam optimizer and the StepLR scheduler.
# training parameters
epochs = 10 if torch.cuda.is_available() else 2
learning_rate = 5e-4
train_batch_size = 32 if torch.cuda.is_available() else 1
test_batch_size = 3
# choose optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8))
# choose supervised training loss
losses = [dinv.loss.SupLoss(metric=dinv.metric.MSE())]
train_dataloader = DataLoader(
train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(
test_dataset, batch_size=test_batch_size, num_workers=num_workers, shuffle=False
)
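With step_size=int(epochs * 0.8) and the default gamma=0.1 of torch.optim.lr_scheduler.StepLR, the learning rate is divided by 10 once int(epochs * 0.8) epochs have elapsed: after 8 epochs in the GPU setting, and already after the first epoch in the 2-epoch CPU setting. The plain-Python sketch below tabulates the resulting schedule, assuming the trainer calls scheduler.step() once per epoch; `step_lr` is a hypothetical helper, not part of PyTorch.

```python
import math


def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Learning rate used during `epoch` under the StepLR rule:
    multiply by gamma once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


for epochs in (10, 2):  # GPU and CPU settings of this example
    step_size = int(epochs * 0.8)
    lrs = [step_lr(5e-4, e, step_size) for e in range(epochs)]
    print(epochs, step_size, lrs)
```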
Train the network
We train the network using the deepinv.Trainer() class.
trainer = dinv.Trainer(
model,
physics=physics,
train_dataloader=train_dataloader,
eval_dataloader=test_dataloader,
epochs=epochs,
scheduler=scheduler,
losses=losses,
optimizer=optimizer,
device=device,
save_path=str(CKPT_DIR / operation),
verbose=verbose,
show_progress_bar=False, # disable the progress bar for cleaner output in the Sphinx gallery
wandb_vis=wandb_vis, # training visualization can be done in Weights & Biases
)
model = trainer.train()
The model has 188174 trainable parameters
Train epoch 0: TotalLoss=140.864, PSNR=-0.094
Eval epoch 0: PSNR=2.717
Train epoch 1: TotalLoss=0.276, PSNR=6.355
Eval epoch 1: PSNR=3.146
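The reported count of 188174 trainable parameters can be checked by hand. Assuming deepinv's DnCNN uses 64 feature maps, 3x3 convolutions with bias and no batch normalization (assumptions, not stated in this example), the denoiser contributes 188163 weights and the unfolded algorithm adds 11 scalars: 5 stepsizes, 5 denoiser noise levels and 1 shared beta. `conv2d_params` is a hypothetical helper.

```python
def conv2d_params(c_in, c_out, kernel_size=3, bias=True):
    """Number of weights (+ biases) of a 2D convolution layer."""
    return c_in * c_out * kernel_size**2 + (c_out if bias else 0)


n_channels, nf, depth, max_iter = 3, 64, 7, 5

# Assumed DnCNN layout: input conv, (depth - 2) hidden convs, output conv, all with bias.
dncnn = (
    conv2d_params(n_channels, nf)
    + (depth - 2) * conv2d_params(nf, nf)
    + conv2d_params(nf, n_channels)
)
# One stepsize and one g_param per iteration, plus a single shared beta.
algo = max_iter + max_iter + 1

print(dncnn, algo, dncnn + algo)  # 188163 11 188174
```

Almost all of the capacity therefore sits in the denoiser, which is shared across the 5 iterations.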
Test the network
trainer.test(test_dataloader)
Eval epoch 0: PSNR=3.146, PSNR no learning=5.996
Test results:
PSNR no learning: 5.996 +- 1.188
PSNR: 3.146 +- 1.309
{'PSNR no learning': 5.9960276285807295, 'PSNR no learning_std': 1.1883839260449329, 'PSNR': 3.1455907821655273, 'PSNR_std': 1.3087291770605716}
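All reported scores are PSNR values in decibels, PSNR = 10 log10(MAX^2 / MSE), where higher is better. Because PSNR is a decreasing function of the MSE, minimizing the supervised MSE loss used above directly targets a higher PSNR. A short NumPy sketch of the metric, assuming images scaled to [0, 1] so MAX = 1; `psnr` and `mse_from_psnr` are hypothetical helpers, not `dinv.metric.PSNR`.

```python
import numpy as np


def psnr(x_hat, x, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB for images with values in [0, max_pixel]."""
    mse = np.mean((x_hat - x) ** 2)
    return 10 * np.log10(max_pixel**2 / mse)


def mse_from_psnr(psnr_db, max_pixel=1.0):
    """Invert the PSNR formula to recover the mean squared error."""
    return max_pixel**2 / 10 ** (psnr_db / 10)


x = np.zeros((3, 8, 8))
x_hat = x + 0.1  # constant error of 0.1, so MSE = 0.01
print(psnr(x_hat, x))  # approximately 20 dB (up to floating-point rounding)
```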
Plotting the trained parameters.
dinv.utils.plotting.plot_parameters(
model, init_params=params_algo, save_dir=RESULTS_DIR / "unfolded_drs" / operation
)
Total running time of the script: (0 minutes 9.651 seconds)