R2RLoss

class deepinv.loss.R2RLoss(metric: Metric | Module = MSELoss(), sigma=0.1, alpha=0.5, eval_n_samples=5)

Bases: Loss

Recorrupted-to-Recorrupted (R2R) Loss

This loss can be used for unsupervised image denoising with unpaired noisy images, without requiring clean targets.

The loss is designed for the noise model:

\[y \sim\mathcal{N}(u,\sigma^2 I) \quad \text{with}\quad u= A(x).\]

The loss is computed as:

\[\| y^- - AR(y^+) \|_2^2 \quad \text{s.t.} \quad y^+ = y + \alpha z, \quad y^- = y - z / \alpha\]

where \(R\) is the trainable network, \(A\) is the forward operator, \(y\) is the noisy measurement, \(z\) is the additional Gaussian noise of standard deviation \(\sigma\), and \(\alpha\) is a scaling factor.
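The loss formula can be illustrated with a minimal NumPy sketch. NumPy stands in for torch here, and the helper name `r2r_loss` is hypothetical, not part of deepinv. The sketch assumes that \(A\) and \(R\) are plain callables and uses the mean squared error, like the default metric.

```python
import numpy as np


def r2r_loss(R, A, y, sigma, alpha, rng):
    """Single-sample R2R loss ||y^- - A(R(y^+))||^2, averaged over entries.

    R: callable standing in for the trainable network (hypothetical).
    A: callable standing in for the forward operator (hypothetical).
    """
    z = sigma * rng.standard_normal(y.shape)  # extra Gaussian noise, std sigma
    y_plus = y + alpha * z                    # re-corrupted network input
    y_minus = y - z / alpha                   # re-corrupted target
    residual = y_minus - A(R(y_plus))
    return np.mean(residual ** 2)             # mean squared error


def identity(v):
    return v  # used both as A = I (denoising) and as a toy network R


rng = np.random.default_rng(0)
x = np.ones(64)                                          # clean signal
y = identity(x) + 0.1 * rng.standard_normal(x.shape)     # y ~ N(A(x), 0.1^2 I)
loss = r2r_loss(identity, identity, y, sigma=0.1, alpha=0.5, rng=rng)
print(loss > 0)
```

With the identity network the residual reduces to \(-(\alpha + 1/\alpha) z\), so as the number of entries grows the loss concentrates around \((\alpha + 1/\alpha)^2\sigma^2 = 0.0625\) for the values above.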

According to the authors of https://ieeexplore.ieee.org/document/9577798, this loss is statistically equivalent to the supervised loss defined on pairs of noisy and clean images.
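Why the equivalence holds can be sketched in a few lines, assuming that the measurement noise \(n = y - u\) is zero-mean with covariance \(\sigma_n^2 I\) and independent of \(z\). The symbols \(n\), \(\sigma_n\) and the dimension \(m\) of \(y\) are introduced here for the derivation only.

```latex
\begin{aligned}
y^+ &= u + (n + \alpha z), \qquad y^- = u + (n - z/\alpha),\\
\mathbb{E}\big[(n + \alpha z)(n - z/\alpha)^\top\big]
  &= \mathbb{E}[n n^\top] - \mathbb{E}[z z^\top] = (\sigma_n^2 - \sigma^2)\, I,\\
\mathbb{E}\,\| y^- - AR(y^+) \|_2^2
  &= \mathbb{E}\,\| u - AR(y^+) \|_2^2 + m\,(\sigma_n^2 + \sigma^2/\alpha^2)
  \quad \text{if } \sigma = \sigma_n .
\end{aligned}
```

When \(\sigma = \sigma_n\), the noise terms of \(y^+\) and \(y^-\) are uncorrelated and jointly Gaussian, hence independent, so the cross term \(\mathbb{E}[(n - z/\alpha)^\top (u - AR(y^+))]\) vanishes. The remaining constant does not depend on \(R\), so minimizing the R2R loss is, in expectation, equivalent to minimizing the supervised loss \(\| u - AR(y^+) \|_2^2\) on the re-corrupted input.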

Warning

The model must be adapted with adapt_model() before training, so that the additional noise is added to its input.

Note

\(\sigma\) should be chosen equal or close to the standard deviation of the noise in the measurements \(y\) to obtain the best performance.
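The effect of the noise level can be checked numerically. The sketch below assumes scalar Gaussian measurement noise of standard deviation `sigma_noise` (a name introduced here, together with the hypothetical helper `noise_cross_cov`) and estimates the covariance between the noise in \(y^+\) and the noise in \(y^-\). In expectation this covariance equals \(\sigma_{\text{noise}}^2 - \sigma^2\), so it vanishes only when \(\sigma\) matches the measurement noise level.

```python
import numpy as np


def noise_cross_cov(sigma_noise, sigma, alpha, n_samples=200_000, seed=0):
    """Empirical covariance between the noise in y^+ and in y^- (scalar case)."""
    rng = np.random.default_rng(seed)
    n = sigma_noise * rng.standard_normal(n_samples)  # measurement noise
    z = sigma * rng.standard_normal(n_samples)        # R2R perturbation
    e_plus = n + alpha * z                             # noise in y^+
    e_minus = n - z / alpha                            # noise in y^-
    return np.mean(e_plus * e_minus)                   # both terms are zero-mean


matched = noise_cross_cov(sigma_noise=0.1, sigma=0.1, alpha=0.5)
mismatched = noise_cross_cov(sigma_noise=0.1, sigma=0.05, alpha=0.5)
print(f"matched: {matched:.5f}, mismatched: {mismatched:.5f}")
```

The first value should be close to 0 and the second close to \(0.1^2 - 0.05^2 = 0.0075\). A nonzero covariance means that the target \(y^-\) is correlated with the network input \(y^+\).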

Note

To obtain the best test performance, the output of the trained model should be averaged at test time over multiple realizations of the added noise, i.e. \(\hat{x} = \frac{1}{N}\sum_{i=1}^N R(y+\alpha z_i)\) where \(N>1\). This averaging is performed by the model returned by adapt_model(), with \(N\) set by eval_n_samples.

Parameters:
  • metric (Metric, torch.nn.Module) – metric for calculating loss, defaults to MSE.

  • sigma (float) – standard deviation of the Gaussian noise used for the perturbation.

  • alpha (float) – scaling factor of the perturbation.

  • eval_n_samples (int) – number of samples \(N\) averaged at evaluation time for the Monte Carlo approximation (a single sample is used during training).


Example:
>>> import torch
>>> import deepinv as dinv
>>> sigma = 0.1
>>> physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
>>> model = dinv.models.MedianFilter()
>>> loss = dinv.loss.R2RLoss(sigma=sigma, eval_n_samples=2)
>>> model = loss.adapt_model(model) # important step!
>>> x = torch.ones((1, 1, 8, 8))
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save extra noise in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
adapt_model(model: Module) → R2RModel

Adds noise to model input.

This method modifies a reconstruction model \(R\) to include the re-corruption (additional noise) at its input:

\[\hat{R}(y) = \frac{1}{N}\sum_{i=1}^N R(y+\alpha z_i)\]

where \(N\geq 1\) is the number of samples used for the Monte Carlo approximation. During training (i.e. after model.train()), only one sample is used, i.e. \(N=1\), for computational efficiency, whereas at test time eval_n_samples samples are averaged for better performance.
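The train/test behaviour described above can be mimicked with a small NumPy sketch. The class `R2RAverager` is hypothetical and is not the deepinv `R2RModel` returned by adapt_model(); it only illustrates \(\hat{R}(y)\) with \(N=1\) in training mode and \(N=\) eval_n_samples in evaluation mode, using an identity function as a toy network.

```python
import numpy as np


class R2RAverager:
    """Hypothetical wrapper computing hat{R}(y) = (1/N) * sum_i R(y + alpha * z_i).

    Illustration only: this is not deepinv's R2RModel.
    """

    def __init__(self, R, sigma, alpha, eval_n_samples, seed=0):
        self.R = R
        self.sigma = sigma
        self.alpha = alpha
        self.eval_n_samples = eval_n_samples
        self.training = True
        self.rng = np.random.default_rng(seed)

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, y):
        # N = 1 in training mode, N = eval_n_samples in evaluation mode
        n_samples = 1 if self.training else self.eval_n_samples
        outputs = []
        for _ in range(n_samples):
            z = self.sigma * self.rng.standard_normal(y.shape)  # z_i with std sigma
            outputs.append(self.R(y + self.alpha * z))
        return np.mean(outputs, axis=0)                          # Monte Carlo average


def identity(v):
    return v  # toy stand-in for the trained network R


y = np.zeros(10_000)
model = R2RAverager(identity, sigma=0.1, alpha=0.5, eval_n_samples=16)
single = model.train()(y)    # one perturbed forward pass
averaged = model.eval()(y)   # average of 16 perturbed passes
print(single.std(), averaged.std())
```

With the identity network each pass only adds \(\alpha z_i\), so the spread of the output drops from about \(\alpha\sigma = 0.05\) to about \(0.05/\sqrt{16} = 0.0125\). With a real network, the average instead approximates the expectation of \(R(y+\alpha z)\) over the added noise.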

Parameters:
  • model (torch.nn.Module) – Reconstruction model.

The standard deviation sigma, the scaling factor alpha and the number of evaluation samples eval_n_samples of the perturbation are those passed to the R2RLoss constructor.

Returns:

(torch.nn.Module) Modified model.

forward(x_net, y, physics, model, **kwargs)

Computes the R2R Loss.

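Parameters:
  • x_net (torch.Tensor) – output of the adapted model, i.e. the reconstruction \(R(y^+)\).

  • y (torch.Tensor) – noisy measurements.

  • physics (deepinv.physics.Physics) – forward operator \(A\) of the measurement model.

  • model (torch.nn.Module) – reconstruction model adapted with adapt_model().

Returns: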

(torch.Tensor) R2R loss.