R2RLoss
- class deepinv.loss.R2RLoss(metric: Metric | Module = MSELoss(), sigma=0.1, alpha=0.5, eval_n_samples=5)[source]
Bases:
Loss
Recorrupted-to-Recorrupted (R2R) Loss
This loss can be used for unsupervised image denoising, i.e. training from single noisy measurements alone, without clean ground-truth images or paired noisy acquisitions.
The loss is designed for the noise model:
\[y \sim \mathcal{N}(u, \sigma^2 I) \quad \text{with} \quad u = A(x).\]
The loss is computed as:
\[\| y^- - AR(y^+) \|_2^2 \quad \text{s.t.} \quad y^+ = y + \alpha z, \quad y^- = y - z / \alpha,\]
where \(R\) is the trainable network, \(A\) is the forward operator, \(y\) is the noisy measurement, \(z\) is additional Gaussian noise of standard deviation \(\sigma\), and \(\alpha\) is a scaling factor.
As shown by the authors of the R2R paper (https://ieeexplore.ieee.org/document/9577798), this loss is statistically equivalent to the supervised loss defined on noisy/clean image pairs.
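For concreteness, here is a minimal sketch of the recorruption step and the loss term above, written directly from the formula. It is not deepinv's internal implementation; denoiser and A are placeholder stand-ins for the trainable network \(R\) and the forward operator.

import torch

sigma, alpha = 0.1, 0.5
y = torch.randn(1, 1, 8, 8)                 # noisy measurement
z = sigma * torch.randn_like(y)             # extra Gaussian noise of standard deviation sigma
y_plus = y + alpha * z                      # recorrupted input  y^+
y_minus = y - z / alpha                     # recorrupted target y^-

denoiser = torch.nn.Identity()              # placeholder for the trainable network R
A = lambda u: u                             # placeholder forward operator (denoising: A = identity)

# || y^- - A R(y^+) ||^2, here in MSE form as used with the default metric
loss_value = ((y_minus - A(denoiser(y_plus))) ** 2).mean()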
Warning
The model should be adapted before training using the method adapt_model() to include the additional noise at the input.
Note
The perturbation standard deviation \(\sigma\) should be chosen equal or close to the standard deviation of the measurement noise to obtain the best performance.
Note
To obtain the best test performance, the trained model should be averaged at test time over multiple realizations of the added noise, i.e. \(\hat{x} = \frac{1}{N}\sum_{i=1}^N R(y+\alpha z_i)\) where \(N>1\). This can be achieved using adapt_model().
- Parameters:
metric (Metric, torch.nn.Module) – metric for calculating loss, defaults to MSE.
sigma (float) – standard deviation of the Gaussian noise used for the perturbation.
alpha (float) – scaling factor of the perturbation.
eval_n_samples (int) – number of samples used for the Monte Carlo approximation.
- Example:
>>> import torch
>>> import deepinv as dinv
>>> sigma = 0.1
>>> physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
>>> model = dinv.models.MedianFilter()
>>> loss = dinv.loss.R2RLoss(sigma=sigma, eval_n_samples=2)
>>> model = loss.adapt_model(model)  # important step!
>>> x = torch.ones((1, 1, 8, 8))
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True)  # save extra noise in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
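The example above performs only a single forward pass. The sketch below extends it into a minimal training loop; the network choice (DnCNN with pretrained=None), the toy measurements, and the optimizer settings are assumptions made here for illustration, and only the adapt_model() and loss call pattern is taken from this page.

import torch
import deepinv as dinv

sigma = 0.1
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
loss = dinv.loss.R2RLoss(sigma=sigma)
net = dinv.models.DnCNN(in_channels=1, out_channels=1, pretrained=None)  # illustrative network choice
net = loss.adapt_model(net)  # important step: adds the extra noise at the input
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

measurements = [physics(torch.rand(1, 1, 8, 8)) for _ in range(4)]  # toy noisy measurements

net.train()
for y in measurements:
    optimizer.zero_grad()
    x_net = net(y, physics, update_parameters=True)  # saves the added noise for the loss
    l = loss(x_net, y, physics, net)
    l.backward()
    optimizer.step()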
- adapt_model(model: Module) → R2RModel[source]
Adds noise to model input.
This method modifies a reconstruction model \(R\) to include the splitting mechanism at the input:
\[\hat{R}(y) = \frac{1}{N}\sum_{i=1}^N R(y+\alpha z_i)\]
where \(N \geq 1\) is the number of samples used for the Monte Carlo approximation. During training (i.e. when model.train() is set), only one sample is used (\(N=1\)) for computational efficiency, whereas at test time multiple samples are averaged for better performance (see the sketch after this method's description).
- Parameters:
model (torch.nn.Module) – Reconstruction model.
- Returns:
(torch.nn.Module) Modified model.
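As a concrete illustration of the train/eval behaviour of the adapted model, the sketch below runs it in eval mode and reproduces the same Monte Carlo average by hand. MedianFilter is used only as a parameter-free stand-in for a trained network \(R\); the two estimates differ only through the random noise realizations.

import torch
import deepinv as dinv

sigma, alpha, N = 0.1, 0.5, 5
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
R = dinv.models.MedianFilter()
model = dinv.loss.R2RLoss(sigma=sigma, alpha=alpha, eval_n_samples=N).adapt_model(R)

y = physics(torch.ones(1, 1, 8, 8))

model.eval()                                   # eval mode: averages over eval_n_samples passes
with torch.no_grad():
    x_hat = model(y, physics)
    # the same estimator written out by hand (different noise realization):
    # x_manual = (1/N) * sum_i R(y + alpha * z_i),  with z_i ~ N(0, sigma^2 I)
    x_manual = sum(R(y + alpha * sigma * torch.randn_like(y)) for _ in range(N)) / N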
- forward(x_net, y, physics, model, **kwargs)[source]
Computes the R2R Loss.
- Parameters:
x_net (torch.Tensor) – Output of the model adapted with adapt_model(), i.e. the reconstruction.
y (torch.Tensor) – Measurements.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction model.
- Returns:
(torch.Tensor) R2R loss.