Self-supervised denoising with the UNSURE loss.#
This example shows you how to train a denoiser network in a fully self-supervised way, i.e., using only noisy images with an unknown noise level, via the UNSURE loss introduced by Tachella et al. [1].
The UNSURE optimization problem for Gaussian denoising with unknown noise level is defined as:

\[\min_{R}\,\max_{\sigma^2}\;\frac{1}{m}\|y-R(y)\|_2^2 + \frac{2\sigma^2}{m\tau}\sum_{i=1}^{m}\left[b\odot\left(R(y+\tau b)-R(y)\right)\right]_i\]

where \(R\) is the trainable network, \(y\) is the noisy image with \(m\) pixels, \(b\sim \mathcal{N}(0,1)\) is a Gaussian random variable, \(\tau\) is a small positive number, and \(\odot\) denotes elementwise multiplication.
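To make the objective concrete, the following is a minimal sketch in plain PyTorch of how it can be evaluated for a given denoiser, noisy batch and candidate noise variance; unsure_objective is a hypothetical helper written for this illustration only, not part of the deepinv API.

import torch

def unsure_objective(R, y, sigma2, tau=1e-3):
    # Hypothetical helper (illustration only, not part of deepinv): evaluates the
    # UNSURE objective for a denoiser R, a noisy batch y and a candidate noise
    # variance sigma2, using the Monte Carlo divergence estimate with b and tau.
    b = torch.randn_like(y)                  # b ~ N(0, 1), same shape as y
    m = y.numel()                            # number of pixels
    mse = (y - R(y)).pow(2).sum() / m        # measurement-consistency term
    div = (b * (R(y + tau * b) - R(y))).sum() / (tau * m)  # divergence estimate
    return mse + 2 * sigma2 * div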
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import deepinv as dinv
from deepinv.utils import get_data_home
Setup paths for data loading and results.#
BASE_DIR = Path(".")
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"
ORIGINAL_DATA_DIR = get_data_home()
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load base image datasets#
In this example, we use the MNIST dataset as the base image dataset.
operation = "denoising"
train_dataset_name = "MNIST"
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
root=ORIGINAL_DATA_DIR, train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
root=ORIGINAL_DATA_DIR, train=False, transform=transform, download=True
)
Generate a dataset of noisy images#
We generate a dataset of noisy images corrupted by Gaussian noise.
Note
We use a subset of the whole training set to reduce the computational load of the example.
We recommend using the whole set (by setting n_images_max=None) to get the best results.
true_sigma = 0.1
# define the physics: denoising with Gaussian noise
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=true_sigma))
# Use a parallel dataloader if using a GPU to speed up training;
# otherwise, since all computations run on the CPU, use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0
n_images_max = (
100 if torch.cuda.is_available() else 5
) # number of images used for training
measurement_dir = DATA_DIR / train_dataset_name / operation
deepinv_datasets_path = dinv.datasets.generate_dataset(
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
device=device,
save_dir=measurement_dir,
train_datapoints=n_images_max,
test_datapoints=n_images_max,
num_workers=num_workers,
dataset_filename="demo_sure",
)
train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)
Dataset has been saved at measurements/MNIST/denoising/demo_sure0.h5
Set up the denoiser network#
We use a simple U-Net architecture with 2 scales as the denoiser network.
model = dinv.models.ArtifactRemoval(
dinv.models.UNet(in_channels=1, out_channels=1, scales=2).to(device)
)
Set up the training parameters#
We set deepinv.loss.SureGaussianLoss as the training loss with the unsure=True option.
The optimization with respect to the noise level is done by stochastic gradient descent with momentum
inside the loss class, so it is seamlessly integrated into the training process.
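Conceptually, this is just a momentum step on \(\sigma^2\) towards the maximizer of the objective above. The snippet below is a simplified illustration of such an update with a stand-in gradient value; it is not the code used inside deepinv.loss.SureGaussianLoss.

# Simplified illustration of a momentum update on the noise variance sigma^2
# (stand-in gradient value; not the internals of dinv.loss.SureGaussianLoss).
sigma2 = 0.05**2                 # initial guess, i.e. sigma_init**2
velocity = 0.0
momentum, step_size = 0.9, 1e-4

grad_sigma2 = 0.3                # stand-in for d(objective)/d(sigma^2)
velocity = momentum * velocity + grad_sigma2
sigma2 += step_size * velocity   # step towards the maximizer of the objective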
Note
There are (UN)SURE losses for various noise distributions. See also deepinv.loss.SurePGLoss for mixed Poisson-Gaussian noise.
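For instance, a SURE-type loss for mixed Poisson-Gaussian measurements could be instantiated roughly as follows; the sigma and gain values are placeholders, and the exact constructor arguments should be checked against the deepinv.loss.SurePGLoss documentation.

# Sketch: SURE-type loss for mixed Poisson-Gaussian noise (placeholder parameters).
pg_loss = dinv.loss.SurePGLoss(sigma=0.05, gain=0.1)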
Note
We train for only 10 epochs to reduce the computational load of the example. We recommend training for more epochs to get the best results.
epochs = 10 # choose training epochs
learning_rate = 5e-4
batch_size = 32 if torch.cuda.is_available() else 1
sigma_init = 0.05 # initial guess for the noise level
step_size = 1e-4 # step size for the optimization of the noise level
momentum = 0.9 # momentum for the optimization of the noise level
# choose self-supervised training loss
loss = dinv.loss.SureGaussianLoss(
sigma=sigma_init, unsure=True, step_size=step_size, momentum=momentum
)
# choose optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
print(f"INIT. noise level {loss.sigma2.sqrt().item():.3f}")
INIT. noise level 0.050
Train the network#
We train the network using the deepinv.Trainer class.
To simulate a realistic self-supervised learning scenario, we do not use any supervised metrics for training, such as PSNR or SSIM, which require clean ground truth images.
Tip
Since the self-supervised loss does not require clean images, we can also use it for evaluation (without updating the noise level, which is equivalent to SURE with the estimated noise level) to monitor the training process, e.g. for early stopping. This is done automatically when metrics=None and early_stop>0 in the trainer.
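For illustration, an equivalent fixed-noise-level SURE loss can be built from the current estimate and used as a ground-truth-free evaluation criterion; this is only a sketch, since the trainer already handles it when metrics=None and early_stop>0.

# Sketch: fixed-noise-level SURE loss built from the current sigma estimate,
# usable for monitoring or early stopping without clean images.
eval_loss = dinv.loss.SureGaussianLoss(sigma=float(loss.sigma2.sqrt()))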
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
# Initialize the trainer
trainer = dinv.Trainer(
model=model,
physics=physics,
epochs=epochs,
losses=loss,
compute_eval_losses=True, # use self-supervised loss for evaluation
early_stop_on_losses=True, # stop using self-supervised eval loss
metrics=None, # no supervised metrics
early_stop=2, # early stop using the self-supervised loss on the test set
optimizer=optimizer,
device=device,
train_dataloader=train_dataloader,
eval_dataloader=train_dataloader,
plot_images=False,
save_path=str(CKPT_DIR / operation),
verbose=True, # print training information
show_progress_bar=False, # disable progress bar for better vis in sphinx gallery.
)
# Train the network
model = trainer.train()
The model has 444737 trainable parameters
Train epoch 0: TotalLoss=0.082
Eval epoch 0: TotalLoss=0.036
Best model saved at epoch 1
Train epoch 1: TotalLoss=0.026
Eval epoch 1: TotalLoss=0.015
Best model saved at epoch 2
Train epoch 2: TotalLoss=0.013
Eval epoch 2: TotalLoss=0.01
Best model saved at epoch 3
Train epoch 3: TotalLoss=0.009
Eval epoch 3: TotalLoss=0.009
Best model saved at epoch 4
Train epoch 4: TotalLoss=0.007
Eval epoch 4: TotalLoss=0.006
Best model saved at epoch 5
Train epoch 5: TotalLoss=0.006
Eval epoch 5: TotalLoss=0.005
Best model saved at epoch 6
Train epoch 6: TotalLoss=0.004
Eval epoch 6: TotalLoss=0.003
Best model saved at epoch 7
Train epoch 7: TotalLoss=0.002
Eval epoch 7: TotalLoss=0.002
Best model saved at epoch 8
Train epoch 8: TotalLoss=0.001
Eval epoch 8: TotalLoss=0.001
Best model saved at epoch 9
Train epoch 9: TotalLoss=0.001
Eval epoch 9: TotalLoss=0.0
Best model saved at epoch 10
Check learned noise level#
We can check the learned noise level by reading the estimate stored in the loss function and comparing it to the true value.
est_sigma = loss.sigma2.sqrt().item()
print(f"LEARNED noise level {est_sigma:.3f}")
print(f"Estimation error noise level {abs(est_sigma-true_sigma):.3f}")
LEARNED noise level 0.110
Estimation error noise level 0.010
Test the network#
test_dataloader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
trainer.plot_images = True
trainer.test(test_dataloader=test_dataloader, metrics=dinv.metric.PSNR())

Eval epoch 0: TotalLoss=0.004, PSNR=22.459, PSNR no learning=19.981
Test results:
PSNR no learning: 19.981 +- 0.108
PSNR: 22.459 +- 0.243
{'PSNR no learning': 19.98084487915039, 'PSNR no learning_std': 0.10819350836126303, 'PSNR': 22.458982849121092, 'PSNR_std': 0.24259799500724102}
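The trained model can also be applied directly to a noisy batch outside the trainer; a minimal sketch reusing the objects defined above:

# Sketch: denoise one test batch with the trained model and compare it with the
# ground-truth images stored in the dataset.
x, y = next(iter(test_dataloader))
x, y = x.to(device), y.to(device)
with torch.no_grad():
    x_hat = model(y, physics)
print(f"PSNR on one batch: {dinv.metric.PSNR()(x_hat, x).mean().item():.2f} dB")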
References:
[1] J. Tachella, M. Davies, and L. Jacques. UNSURE: self-supervised learning with Unknown Noise level and Stein's Unbiased Risk Estimate. arXiv:2409.01985, 2024.