Self-supervised denoising with the UNSURE loss.

This example shows how to train a denoiser network in a fully self-supervised way, i.e., using only noisy images with an unknown noise level, via the UNSURE loss introduced in https://arxiv.org/abs/2409.01985.

The UNSURE optimization problem for Gaussian denoising with unknown noise level is defined as:

\[\min_{R} \max_{\sigma^2} \frac{1}{m}\|y-R(y)\|_2^2 +\frac{2\sigma^2}{m\tau}b^{\top} \left(R(y+\tau b)-R(y)\right)\]

where \(R\) is the trainable network, \(y\) is the noisy image with \(m\) pixels, \(b\sim \mathcal{N}(0,I)\) is a standard Gaussian random vector with the same shape as \(y\), and \(\tau\) is a small positive number.
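The second term of the objective is a Monte Carlo estimate of the divergence of \(R\) at \(y\), scaled by \(2\sigma^2/m\). The following is a minimal pure-Python sketch of how the objective could be evaluated for a single image stored as a flat list of pixels. The names `unsure_objective` and `shrink` are hypothetical and are not part of deepinv; the toy denoiser is a linear shrinkage, chosen because its divergence estimate is known in closed form (\(a\,\|b\|_2^2\)).

```python
import random


def unsure_objective(denoiser, y, sigma2, tau=1e-2, rng=None):
    """Evaluate the objective for one image y (flat list of m pixels).

    Returns the objective value, the unnormalized Monte Carlo divergence
    estimate b^T (R(y + tau * b) - R(y)) / tau, and the probe vector b.
    """
    rng = rng or random.Random(0)
    m = len(y)
    b = [rng.gauss(0.0, 1.0) for _ in range(m)]  # b ~ N(0, I)
    r_y = denoiser(y)
    r_pert = denoiser([yi + tau * bi for yi, bi in zip(y, b)])
    # Data-fidelity term: (1/m) * ||y - R(y)||^2
    residual = sum((yi - ri) ** 2 for yi, ri in zip(y, r_y)) / m
    # Divergence estimate: b^T (R(y + tau * b) - R(y)) / tau
    div = sum(bi * (rp - ri) for bi, rp, ri in zip(b, r_pert, r_y)) / tau
    return residual + 2.0 * sigma2 * div / m, div, b


def shrink(y, a=0.5):
    """Toy linear denoiser R(y) = a * y (illustration only)."""
    return [a * yi for yi in y]


y = [0.1 * i for i in range(8)]
objective, div, b = unsure_objective(shrink, y, sigma2=0.1 ** 2)
print(objective, div)
```

For a fixed network the objective is linear in \(\sigma^2\), with slope proportional to the divergence estimate, so the inner maximization raises the noise level estimate whenever that estimate is positive.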

import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from pathlib import Path
from torchvision import transforms, datasets

Set up paths for data loading and results.

BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"

# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load base image datasets

In this example, we use the MNIST dataset as the base image dataset.

operation = "denoising"
train_dataset_name = "MNIST"

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root="../datasets/", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root="../datasets/", train=False, transform=transform, download=True
)

Generate a dataset of noisy images

We generate a dataset of noisy images corrupted by Gaussian noise.

Note

We use a subset of the training set to reduce the computational load of the example. For the best results, we recommend using the whole set by setting n_images_max=None.

true_sigma = 0.1

# define the physics: denoising with Gaussian noise of standard deviation true_sigma
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=true_sigma))

# Use parallel data loading on a GPU to speed up training;
# otherwise, since all computation runs on the CPU, use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0

n_images_max = (
    100 if torch.cuda.is_available() else 5
)  # number of images used for training

measurement_dir = DATA_DIR / train_dataset_name / operation
deepinv_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    test_datapoints=n_images_max,
    num_workers=num_workers,
    dataset_filename="demo_sure",
)

train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)
Dataset has been saved in measurements/MNIST/denoising
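The noise model applied when generating the measurements is \(y = x + \sigma n\) with \(n\sim\mathcal{N}(0,I)\). As a rough illustration in plain Python (the helper `add_gaussian_noise` is hypothetical and unrelated to how deepinv implements its physics), one can check that the empirical standard deviation of the added noise is close to the chosen \(\sigma = 0.1\):

```python
import math
import random


def add_gaussian_noise(x, sigma, rng):
    """Return y = x + sigma * n with n ~ N(0, I), for a flat list of pixels."""
    return [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]


rng = random.Random(0)
x = [0.5] * 10000  # constant toy "image" with 10000 pixels
y_noisy = add_gaussian_noise(x, 0.1, rng)

# Empirical standard deviation of the noise y - x
noise = [yi - xi for yi, xi in zip(y_noisy, x)]
emp_std = math.sqrt(sum(n * n for n in noise) / len(noise))
print(f"empirical noise std: {emp_std:.3f}")
```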

Set up the denoiser network

We use a simple U-Net architecture with 2 scales as the denoiser network.

model = dinv.models.ArtifactRemoval(
    dinv.models.UNet(in_channels=1, out_channels=1, scales=2).to(device)
)

Set up the training parameters

We set deepinv.loss.SureGaussianLoss as the training loss with the unsure=True option. The optimization with respect to the noise level is done by stochastic gradient descent with momentum inside the loss class, so it is seamlessly integrated into the training process.

Note

There are (UN)SURE losses for various noise distributions. See also deepinv.loss.SurePGLoss for mixed Poisson-Gaussian noise.

Note

We train for only 10 epochs to reduce the computational load of the example. For the best results, we recommend training for more epochs.

epochs = 10  # choose training epochs
learning_rate = 5e-4
batch_size = 32 if torch.cuda.is_available() else 1

sigma_init = 0.05  # initial guess for the noise level
step_size = 1e-4  # step size for the optimization of the noise level
momentum = 0.9  # momentum for the optimization of the noise level

# choose self-supervised training loss
loss = dinv.loss.SureGaussianLoss(
    sigma=sigma_init, unsure=True, step_size=step_size, momentum=momentum
)

# choose optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)

print(f"INIT. noise level {loss.sigma2.sqrt().item():.3f}")
INIT. noise level 0.050
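To give an intuition for the roles of step_size and momentum, here is a minimal sketch of one momentum-based ascent step on \(\sigma^2\). It assumes an exponential-moving-average form of momentum; the exact update rule inside deepinv.loss.SureGaussianLoss may differ, and the function `ascent_step` is hypothetical. From the objective, the gradient with respect to \(\sigma^2\) is \(\frac{2}{m\tau}b^{\top}(R(y+\tau b)-R(y))\); a constant placeholder value is used for it below.

```python
def ascent_step(sigma2, velocity, grad, step_size=1e-4, momentum=0.9):
    """One gradient-ascent step on sigma^2 with momentum (illustrative form).

    grad is the derivative of the objective with respect to sigma^2.
    """
    # Exponential moving average of past gradients (assumed form)
    velocity = momentum * velocity + (1.0 - momentum) * grad
    # Ascent: move sigma^2 in the direction of the smoothed gradient
    sigma2 = sigma2 + step_size * velocity
    return sigma2, velocity


sigma2, velocity = 0.05 ** 2, 0.0  # sigma_init = 0.05
for grad in [1.0, 1.0, 1.0]:  # placeholder gradient values
    sigma2, velocity = ascent_step(sigma2, velocity, grad)
print(sigma2, velocity)
```

With a small step size such as 1e-4, the noise level estimate moves slowly relative to the network weights, which keeps the two updates from destabilizing each other.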

Train the network

We train the network using the deepinv.Trainer class.

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)

# Initialize the trainer
trainer = dinv.Trainer(
    model=model,
    physics=physics,
    epochs=epochs,
    losses=loss,
    optimizer=optimizer,
    device=device,
    train_dataloader=train_dataloader,
    plot_images=False,
    save_path=str(CKPT_DIR / operation),
    verbose=True,  # print training information
    show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
)

# Train the network
model = trainer.train()
The model has 444737 trainable parameters
Train epoch 0: TotalLoss=0.082, PSNR=11.174
Train epoch 1: TotalLoss=0.029, PSNR=14.826
Train epoch 2: TotalLoss=0.015, PSNR=17.465
Train epoch 3: TotalLoss=0.01, PSNR=19.055
Train epoch 4: TotalLoss=0.007, PSNR=20.112
Train epoch 5: TotalLoss=0.006, PSNR=21.228
Train epoch 6: TotalLoss=0.005, PSNR=21.775
Train epoch 7: TotalLoss=0.005, PSNR=22.558
Train epoch 8: TotalLoss=0.004, PSNR=23.066
Train epoch 9: TotalLoss=0.004, PSNR=23.374

Check learned noise level

We can verify the learned noise level by checking the estimated noise level from the loss function.

est_sigma = loss.sigma2.sqrt().item()

print(f"LEARNED noise level {est_sigma:.3f}")
print(f"Estimation error noise level {abs(est_sigma-true_sigma):.3f}")
LEARNED noise level 0.097
Estimation error noise level 0.003

Test the network

Figure panels: Ground truth, Measurement, No learning, Reconstruction
Eval epoch 0: PSNR=23.689, PSNR no learning=19.981
Test results:
PSNR no learning: 19.981 +- 0.108
PSNR: 23.689 +- 0.881

{'PSNR no learning': 19.980844879150389, 'PSNR no learning_std': 0.10819350836126303, 'PSNR': 23.688781356811525, 'PSNR_std': 0.88140769703778066}
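The "no learning" PSNR is consistent with the noise level used to generate the data. Assuming PSNR is computed as \(10\log_{10}(\text{max}^2/\text{MSE})\) with a maximum pixel value of 1 (an assumption about the metric settings), a noisy image with \(\sigma = 0.1\) has an expected MSE of \(\sigma^2 = 0.01\), i.e. about 20 dB. A short sketch with a hypothetical `psnr` helper:

```python
import math


def psnr(x, x_hat, max_pixel=1.0):
    """PSNR in dB between two flat lists of pixels (assumes max pixel value 1)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    return 10.0 * math.log10(max_pixel ** 2 / mse)


# Toy image whose per-pixel error is exactly +/- 0.1, so that MSE = 0.01
x_clean = [0.5, 0.5, 0.5, 0.5]
x_noisy = [0.6, 0.4, 0.6, 0.4]
psnr_noisy = psnr(x_clean, x_noisy)
print(f"PSNR of the noisy toy image: {psnr_noisy:.2f} dB")
```

The reported value of 19.981 dB for the measurements is close to this 20 dB figure, and the trained network improves on it by roughly 3.7 dB.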

Total running time of the script: (0 minutes 3.227 seconds)
