DPIR method for PnP image deblurring.#

This example shows how to use the DPIR method to solve a PnP image deblurring problem. The DPIR method is described in Zhang et al.[1].
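
Concretely, the measurements used in this example follow the standard deblurring model, where A is the blur operator defined below, x the ground-truth image, and sigma the Gaussian noise level (notation introduced here as a sketch, not taken from the original text):

y = A x + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2 I).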

import deepinv as dinv
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from deepinv.models import DRUNet
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP
from deepinv.optim.optimizers import optim_builder
from deepinv.training import test
from torchvision import transforms
from deepinv.optim.dpir import get_DPIR_params
from deepinv.utils import load_dataset, load_degradation

Setup paths for data loading and results.#

BASE_DIR = Path(".")
DATA_DIR = BASE_DIR / "measurements"
RESULTS_DIR = BASE_DIR / "results"
DEG_DIR = BASE_DIR / "degradations"

Load base image datasets and degradation operators.#

In this example, we use the Set3C dataset and a motion blur kernel from Levin et al.[2].

# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

# Set up the variables to fetch the dataset and operators.
method = "DPIR"
dataset_name = "set3c"
img_size = 128 if torch.cuda.is_available() else 32
val_transform = transforms.Compose(
    [transforms.CenterCrop(img_size), transforms.ToTensor()]
)

# Generate a motion blur operator.
kernel_index = 1  # which kernel to choose among the 8 motion kernels from 'Levin09.npy'
kernel_torch = load_degradation("Levin09.npy", DEG_DIR / "kernels", index=kernel_index)
kernel_torch = kernel_torch.unsqueeze(0).unsqueeze(
    0
)  # add batch and channel dimensions
dataset = load_dataset(dataset_name, transform=val_transform)
Levin09.npy degradation downloaded in degradations/kernels
Downloading datasets/set3c.zip

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 385k/385k [00:00<00:00, 69.2MiB/s]
set3c dataset downloaded in datasets
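
Before blurring the whole dataset, one can sanity-check the degradation. The snippet below is a sketch; it assumes that dinv.utils.plot accepts a list of image tensors together with a list of titles.

print(kernel_torch.shape)  # expected (1, 1, kh, kw) after adding the batch and channel dimensions
dinv.utils.plot([kernel_torch], titles=["motion blur kernel"])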

Generate a dataset of blurred images and load it.#

We use the BlurFFT class from the physics module to generate a dataset of blurred images.

noise_level_img = 0.03  # Gaussian noise standard deviation for the degradation
n_channels = 3  # 3 for color images, 1 for gray-scale images
p = dinv.physics.BlurFFT(
    img_size=(n_channels, img_size, img_size),
    filter=kernel_torch,
    device=device,
    noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
)

# Use a parallel dataloader if running on GPU to speed up data loading;
# otherwise, as all computations run on the CPU, use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0

n_images_max = 3  # Maximum number of images to restore from the input dataset
# Generate a dataset in an HDF5 file at "{dir}/dinv_dataset0.h5" and load it.
operation = "deblur"
measurement_dir = DATA_DIR / dataset_name / operation
dinv_dataset_path = dinv.datasets.generate_dataset(
    train_dataset=dataset,
    test_dataset=None,
    physics=p,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    num_workers=num_workers,
)

batch_size = 3  # batch size for testing. As the number of iterations is fixed, we can use batch_size > 1
# and restore multiple images in parallel.
dataset = dinv.datasets.HDF5Dataset(path=dinv_dataset_path, train=True)
Dataset has been saved at measurements/set3c/deblur/dinv_dataset0.h5
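
Each element of the generated dataset is a (ground truth, measurement) pair; a minimal sketch to inspect one pair, assuming dinv.utils.plot as above.

x0, y0 = dataset[0]  # ground-truth image and its blurred, noisy measurement (no batch dimension)
dinv.utils.plot([x0.unsqueeze(0), y0.unsqueeze(0)], titles=["ground truth", "measurement"])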

Set up the DPIR algorithm to solve the inverse problem.#

This method is based on half-quadratic splitting (HQS). The algorithm alternates between a denoising step and a data fidelity step, where the denoising step is performed by a pretrained denoiser deepinv.models.DRUNet.
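
Schematically, writing f(x) = \tfrac{1}{2\sigma^2}\lVert Ax - y\rVert^2 for the data fidelity term and D_{\sigma_k} for the denoiser at noise level \sigma_k, the HQS iterations alternate the two steps below. This is only a sketch of the splitting; the exact stepsize and noise-level schedule is the one returned by get_DPIR_params.

\begin{aligned}
u_k &= \operatorname{prox}_{\gamma_k f}(x_{k-1}) = \arg\min_u \tfrac{1}{2\sigma^2}\lVert A u - y\rVert^2 + \tfrac{1}{2\gamma_k}\lVert u - x_{k-1}\rVert^2, \\
x_k &= D_{\sigma_k}(u_k).
\end{aligned}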

Note

We provide a wrapper for rapidly creating the DPIR algorithm in deepinv.optim.DPIR.
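
For comparison, the same model can in principle be built in one line with that wrapper; the call below is a sketch and assumes the wrapper exposes the measurement noise level through a sigma argument (check the deepinv API reference for the exact signature).

# Sketch: build DPIR via the high-level wrapper instead of optim_builder
# (the `sigma` keyword is assumed here; verify against the installed deepinv version).
dpir_model = dinv.optim.DPIR(sigma=noise_level_img, device=device)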

# load specific parameters for DPIR
sigma_denoiser, stepsize, max_iter = get_DPIR_params(noise_level_img, device=device)
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser}
early_stop = False  # Do not stop the algorithm with a convergence criterion

# Select the data fidelity term
data_fidelity = L2()

# Specify the denoising prior
prior = PnP(denoiser=DRUNet(pretrained="download", device=device))

# Instantiate the algorithm class to solve the inverse problem.
model = optim_builder(
    iteration="HQS",
    prior=prior,
    data_fidelity=data_fidelity,
    early_stop=early_stop,
    max_iter=max_iter,
    verbose=True,
    params_algo=params_algo,
)

# Set the model to evaluation mode. We do not require training here.
model.eval()
BaseOptim(
  (fixed_point): FixedPoint(
    (iterator): HQSIteration(
      (f_step): fStepHQS()
      (g_step): gStepHQS()
    )
  )
  (psnr): PSNR()
)
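
Before running the full evaluation loop, the model can also be applied to a single measurement directly. The sketch below assumes that the HDF5 dataset returns (ground truth, measurement) pairs and that the optimization model is called as model(y, physics).

x0, y0 = dataset[0]
y0 = y0.unsqueeze(0).to(device)  # add a batch dimension and move to the compute device
with torch.no_grad():
    x_hat = model(y0, p)  # run the HQS iterations on this single measurement
print(dinv.metric.PSNR()(x_hat, x0.unsqueeze(0).to(device)).item())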

Evaluate the model on the problem.#

The test function evaluates the model on the test dataset and computes the metrics.
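
The results below were produced by a call along the following lines; this is a sketch in which the argument names follow deepinv.training.test and the metrics mirror the PSNR/LPIPS values reported in the output, but the exact signature should be checked against the installed deepinv version.

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

test(
    model=model,
    test_dataloader=dataloader,
    physics=p,
    metrics=[dinv.metric.PSNR(), dinv.metric.LPIPS(device=device)],
    device=device,
    plot_images=True,  # plot ground truth, measurement, no-learning estimate and reconstruction
    plot_convergence_metrics=True,  # plot PSNR and residual along the iterations
    save_folder=RESULTS_DIR / method / operation / dataset_name,
    verbose=True,
)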

  • Figure: Ground truth, Measurement, No learning, Reconstruction
  • Figure: PSNR and residual (convergence metrics)
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /home/runner/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 233M/233M [00:00<00:00, 311MB/s]
Downloading: "https://huggingface.co/chaofengc/IQA-PyTorch-Weights/resolve/main/LPIPS_v0.1_alex-df73285e.pth" to /home/runner/.cache/torch/hub/pyiqa/LPIPS_v0.1_alex-df73285e.pth


100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5.87k/5.87k [00:00<00:00, 39.4MB/s]
Loading pretrained model LPIPS from /home/runner/.cache/torch/hub/pyiqa/LPIPS_v0.1_alex-df73285e.pth

Test: 100%|██████████| 1/1 [00:01<00:00,  1.06s/it, PSNR=30.2, PSNR no learning=17, LPIPS=0.00378, LPIPS no learning=0.237]
Test results:
PSNR no learning: 16.968 +- 0.647
PSNR: 30.165 +- 2.016
LPIPS no learning: 0.237 +- 0.043
LPIPS: 0.004 +- 0.003

{'PSNR no learning': 16.96754201253255, 'PSNR no learning_std': 0.6471807496631378, 'PSNR': 30.164637247721355, 'PSNR_std': 2.0164762930579503, 'LPIPS no learning': 0.2374823490778605, 'LPIPS no learning_std': 0.042793832110393296, 'LPIPS': 0.003775439535578092, 'LPIPS_std': 0.002592691248814548}
References:

[1] Zhang, K., Zuo, W., Gu, S., & Zhang, L. (2017). Learning deep CNN denoiser prior for image restoration. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 3929-3938).

[2] Levin, A., Weiss, Y., Durand, F., & Freeman, W. T. (2009). Understanding and evaluating blind deconvolution algorithms. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition.

Total running time of the script: (0 minutes 3.235 seconds)
