Self-supervised denoising with the SURE loss.

This example shows how to train a denoiser network in a fully self-supervised way, i.e., using only noisy images, with the SURE loss, which exploits knowledge of the noise distribution.

The SURE loss for Poisson denoising acts as an unbiased estimator of the supervised loss and is computed as:

\[\frac{1}{m}\|y-R(y)\|_2^2-\frac{\gamma}{m}\mathbf{1}^{\top}y+\frac{2\gamma}{m\tau}(b\odot y)^{\top}\left(R(y+\tau b)-R(y)\right)\]

where \(R\) is the trainable network, \(y\) is the noisy image with \(m\) pixels, \(\gamma\) is the gain of the Poisson noise, \(\mathbf{1}\) is the all-ones vector, \(b\) is a Bernoulli random vector whose entries take the values -1 and 1 with probability 0.5 each, \(\tau\) is a small positive number, and \(\odot\) denotes elementwise multiplication.
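To make the formula concrete, here is a minimal pure-Python sketch (not the deepinv implementation) that evaluates the expression above for a toy pixel-wise affine denoiser \(R(y) = 0.8\,y + 0.05\) and compares its Monte Carlo average with the supervised MSE \(\frac{1}{m}\|x - R(y)\|_2^2\), which requires the clean image \(x\). It assumes the scaled Poisson model \(y = \gamma z\) with \(z \sim \mathcal{P}(x/\gamma)\); the helper names (`poisson_sample`, `add_poisson_noise`, `sure_poisson`, `affine_denoiser`) and the values of \(\gamma\), \(\tau\) and \(m\) are hypothetical choices for illustration. For this affine denoiser the two averages agree in expectation; for a nonlinear network such as the U-Net used later, the finite-difference term is only an approximation.

```python
import math
import random


def poisson_sample(lam, rng):
    """Draw one Poisson(lam) sample with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    k, p = 0, rng.random()
    while p > threshold:
        k += 1
        p *= rng.random()
    return k


def add_poisson_noise(x, gain, rng):
    """Assumed scaled Poisson model: y = gain * Poisson(x / gain)."""
    return [gain * poisson_sample(xi / gain, rng) for xi in x]


def sure_poisson(y, denoiser, gain, tau, rng):
    """Evaluate the Poisson SURE formula for one noisy image y (a flat list)."""
    m = len(y)
    b = [rng.choice((-1.0, 1.0)) for _ in range(m)]  # +/-1 probe vector
    r_y = denoiser(y)
    r_pert = denoiser([yi + tau * bi for yi, bi in zip(y, b)])
    data_fit = sum((yi - ri) ** 2 for yi, ri in zip(y, r_y)) / m
    offset = gain * sum(y) / m
    div = sum(bi * yi * (rp - ri) for bi, yi, rp, ri in zip(b, y, r_pert, r_y))
    return data_fit - offset + 2 * gain / (m * tau) * div


def affine_denoiser(v):
    """Hypothetical stand-in for the trainable network R."""
    return [0.8 * vi + 0.05 for vi in v]


rng = random.Random(0)
gain, tau, m, trials = 0.1, 1e-3, 100, 500
x = [rng.uniform(0.2, 1.0) for _ in range(m)]  # clean image, only used for the check

sure_vals, mse_vals = [], []
for _ in range(trials):
    y = add_poisson_noise(x, gain, rng)
    sure_vals.append(sure_poisson(y, affine_denoiser, gain, tau, rng))
    r_y = affine_denoiser(y)
    mse_vals.append(sum((xi - ri) ** 2 for xi, ri in zip(x, r_y)) / m)

mean_sure = sum(sure_vals) / trials
mean_mse = sum(mse_vals) / trials
print(f"mean SURE = {mean_sure:.5f}, mean supervised MSE = {mean_mse:.5f}")
```

Note that only `sure_poisson` is computed from noisy data alone; the clean image `x` appears solely in the reference MSE used for the comparison.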

import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from pathlib import Path
from torchvision import transforms, datasets
from deepinv.models.utils import get_weights_url

Set up paths for data loading and results.

BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"

# Set the global PyTorch random seed to make the example reproducible.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load base image datasets

In this example, we use the MNIST dataset as the base image dataset.

operation = "denoising"
train_dataset_name = "MNIST"

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=False, transform=transform, download=True
)

Generate a dataset of noisy images

We generate a dataset of noisy images corrupted by Poisson noise.
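The Poisson model behind `dinv.physics.PoissonNoise(0.1)` can be sketched as \(y = \gamma z\) with \(z \sim \mathcal{P}(x/\gamma)\), where \(\gamma\) is the gain; we assume this is the convention deepinv uses. Under this model each noisy pixel is unbiased, \(\mathbb{E}[y] = x\), with signal-dependent variance \(\gamma x\), which is the noise energy that the \(-\frac{\gamma}{m}\mathbf{1}^{\top}y\) term of the SURE loss subtracts. A quick pure-Python check of these two moments for a single hypothetical pixel value:

```python
import math
import random


def poisson_sample(lam, rng):
    """Draw one Poisson(lam) sample with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    k, p = 0, rng.random()
    while p > threshold:
        k += 1
        p *= rng.random()
    return k


rng = random.Random(1)
gain, x_pixel, n = 0.1, 0.5, 20000  # one clean pixel value, many noisy draws

# Assumed scaled Poisson model: y = gain * Poisson(x / gain)
samples = [gain * poisson_sample(x_pixel / gain, rng) for _ in range(n)]
mean = sum(samples) / n
var = sum((s - mean) ** 2 for s in samples) / (n - 1)
print(f"mean = {mean:.3f} (x = {x_pixel}), variance = {var:.4f} (gain * x = {gain * x_pixel})")
```

A larger gain therefore means stronger noise at the same intensity, which is why the gain passed to the loss has to describe the actual measurements.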

Note

We use a subset of the whole training set to reduce the computational load of the example. We recommend using the whole set, by setting n_images_max=None, to get the best results.

# define the physics: denoising with Poisson noise of gain 0.1
physics = dinv.physics.Denoising(dinv.physics.PoissonNoise(0.1))

# Use a parallel dataloader when a GPU is available to speed up training;
# otherwise all computation runs on the CPU, so load data synchronously.
num_workers = 4 if torch.cuda.is_available() else 0

n_images_max = (
    100 if torch.cuda.is_available() else 5
)  # number of images used for training (and for testing)

measurement_dir = DATA_DIR / train_dataset_name / operation
deepinv_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    test_datapoints=n_images_max,
    num_workers=num_workers,
    dataset_filename="demo_sure",
)

train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)
Dataset has been saved in measurements/MNIST/denoising

Set up the denoiser network

We use a simple U-Net architecture with 2 scales as the denoiser network.

model = dinv.models.ArtifactRemoval(
    dinv.models.UNet(in_channels=1, out_channels=1, scales=2).to(device)
)

Set up the training parameters

We set deepinv.loss.SurePoissonLoss as the training loss.

Note

There are SURE losses for various noise distributions. See also deepinv.loss.SureGaussianLoss for Gaussian noise and deepinv.loss.SurePGLoss for mixed Poisson-Gaussian noise.
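For comparison with the Poisson case, the textbook Monte Carlo SURE for additive Gaussian noise \(y = x + n\), \(n \sim \mathcal{N}(0, \sigma^2 I)\), reads \(\frac{1}{m}\|y-R(y)\|_2^2-\sigma^2+\frac{2\sigma^2}{m\tau}b^{\top}\left(R(y+\tau b)-R(y)\right)\). The pure-Python sketch below is our own illustration of that estimator, not the code of deepinv.loss.SureGaussianLoss, whose normalization and choice of probe \(b\) may differ; the affine denoiser and all parameter values are hypothetical.

```python
import random


def sure_gaussian(y, denoiser, sigma, tau, rng):
    """Monte Carlo SURE for y = x + N(0, sigma^2 I), averaged over the m pixels."""
    m = len(y)
    b = [rng.choice((-1.0, 1.0)) for _ in range(m)]  # +/-1 probe vector
    r_y = denoiser(y)
    r_pert = denoiser([yi + tau * bi for yi, bi in zip(y, b)])
    data_fit = sum((yi - ri) ** 2 for yi, ri in zip(y, r_y)) / m
    # finite-difference estimate of the divergence of R at y
    div = sum(bi * (rp - ri) for bi, rp, ri in zip(b, r_pert, r_y)) / tau
    return data_fit - sigma**2 + 2 * sigma**2 * div / m


def affine_denoiser(v):
    """Hypothetical stand-in for the trainable network R."""
    return [0.8 * vi + 0.05 for vi in v]


rng = random.Random(0)
sigma, tau, m, trials = 0.2, 1e-3, 100, 500
x = [rng.uniform(0.0, 1.0) for _ in range(m)]  # clean image, only used for the check

sure_vals, mse_vals = [], []
for _ in range(trials):
    y = [xi + rng.gauss(0.0, sigma) for xi in x]
    sure_vals.append(sure_gaussian(y, affine_denoiser, sigma, tau, rng))
    r_y = affine_denoiser(y)
    mse_vals.append(sum((xi - ri) ** 2 for xi, ri in zip(x, r_y)) / m)

mean_sure = sum(sure_vals) / trials
mean_mse = sum(mse_vals) / trials
print(f"mean SURE = {mean_sure:.5f}, mean supervised MSE = {mean_mse:.5f}")
```

Compared with the Poisson version, the constant \(\sigma^2\) replaces the data-dependent term \(\frac{\gamma}{m}\mathbf{1}^{\top}y\), and the divergence term is no longer weighted by \(y\).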

Note

We use a pretrained model to reduce training time. You can get the same results by training from scratch for 10 epochs.
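The scheduler defined in the next code block is a StepLR with step size int(epochs * 0.8) + 1, so the learning rate is multiplied by StepLR's default factor gamma=0.1 once about 80% of the epochs have passed. The pure-Python sketch below reproduces that schedule, assuming the trainer calls scheduler.step() once per epoch; the function name step_lr_schedule is hypothetical. For the 10-epoch run from scratch mentioned in the note, the step size is 9, so only the last epoch uses the reduced rate.

```python
def step_lr_schedule(base_lr, epochs, gamma=0.1):
    """Learning rate of each epoch under StepLR(step_size=int(epochs * 0.8) + 1)."""
    step_size = int(epochs * 0.8) + 1
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(epochs)]


print(step_lr_schedule(5e-4, epochs=1))   # schedule used in this example
print(step_lr_schedule(5e-4, epochs=10))  # schedule when training from scratch
```

With epochs = 1, as in this example, the step size is 1 and the single epoch runs entirely at the initial learning rate.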

epochs = 1  # choose training epochs
learning_rate = 5e-4
batch_size = 32 if torch.cuda.is_available() else 1

# choose the self-supervised training loss; its gain must match the noise gain (0.1) above
loss = dinv.loss.SurePoissonLoss(gain=0.1)

# choose optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8) + 1)

# start with a pretrained model to reduce training time
file_name = "ckp_10_demo_sure.pth"
url = get_weights_url(model_name="demo", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)
# load the pretrained weights and the matching optimizer state
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
Downloading: "https://huggingface.co/deepinv/demo/resolve/main/ckp_10_demo_sure.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/ckp_10_demo_sure.pth

Train the network

verbose = True  # print training information
wandb_vis = False  # plot curves and images in Weights & Biases

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

# Initialize the trainer
trainer = dinv.Trainer(
    model=model,
    physics=physics,
    epochs=epochs,
    scheduler=scheduler,
    losses=loss,
    optimizer=optimizer,
    device=device,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    plot_images=True,
    save_path=str(CKPT_DIR / operation),
    verbose=verbose,
    show_progress_bar=False,  # disable the progress bar for cleaner output in Sphinx-Gallery
    wandb_vis=wandb_vis,
)

# Train the network
model = trainer.train()
(Figures: Ground truth, Measurement, Reconstruction)
The model has 444737 trainable parameters
Train epoch 0: TotalLoss=0.002, PSNR=24.558
Eval epoch 0: PSNR=23.876

Test the network

trainer.test(test_dataloader)
(Figure: Ground truth, Measurement, No learning, Reconstruction)
Eval epoch 0: PSNR=23.876, PSNR no learning=19.346
Test results:
PSNR no learning: 19.346 +- 1.786
PSNR: 23.876 +- 2.020

{'PSNR no learning': 19.346022415161134, 'PSNR no learning_std': 1.7863402359700375, 'PSNR': 23.876024627685545, 'PSNR_std': 2.0200106671353986}
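PSNR is \(10\log_{10}(\text{max}^2/\text{MSE})\); here we assume a peak value of 1, since ToTensor scales MNIST images to [0, 1]. The pure-Python sketch below (the psnr helper is hypothetical, not deepinv's metric class) computes it for a toy image and converts the reported gain of 23.876 - 19.346 = 4.53 dB into an MSE ratio. Because the reported figures are presumably averages of per-image PSNR values, that ratio is only a rough indication of how much the network reduces the error compared with the no-learning baseline.

```python
import math


def psnr(x, x_hat, max_pixel=1.0):
    """PSNR in dB between two flat lists of pixel values."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    return 10 * math.log10(max_pixel**2 / mse)


# Toy example: a uniform error of 0.1 gives MSE = 0.01, i.e. 20 dB.
clean = [0.0, 0.25, 0.5, 0.75]
noisy = [v + 0.1 for v in clean]
print(f"{psnr(clean, noisy):.2f} dB")

# Average PSNR gain reported above, converted to a ratio of MSEs.
gain_db = 23.876 - 19.346
mse_ratio = 10 ** (gain_db / 10)
print(f"{gain_db:.2f} dB gain ~ {mse_ratio:.2f}x lower MSE")
```

Every 10 dB of PSNR corresponds to a tenfold reduction in MSE, so a gain of about 4.5 dB means the error is roughly divided by 2.8.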

Total running time of the script: (0 minutes 1.468 seconds)
