SurePGLoss
- class deepinv.loss.SurePGLoss(sigma, gain, tau1=0.001, tau2=0.01, second_derivative=False, unsure=False, step_size=(0.0001, 0.0001), momentum=(0.9, 0.9), rng=None)
Bases: Loss
SURE loss for Poisson-Gaussian noise
The loss is designed for the following noise model:
\[y = \gamma z + \epsilon\] where \(u = A(x)\), \(z \sim \mathcal{P}\left(\frac{u}{\gamma}\right)\), and \(\epsilon \sim \mathcal{N}(0, \sigma^2 I)\).
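The noise model can be simulated directly with torch.poisson and torch.randn. The sketch below (the helper name poisson_gaussian_noise is hypothetical) draws samples for a constant \(u\) and reports the first two moments, which should be close to \(\mathbb{E}[y] = u\) and \(\mathrm{Var}[y] = \gamma u + \sigma^2\).

```python
import torch

def poisson_gaussian_noise(u, gamma, sigma, generator=None):
    """Sample y = gamma * z + eps with z ~ Poisson(u / gamma) and eps ~ N(0, sigma^2)."""
    z = torch.poisson(u / gamma, generator=generator)                       # Poisson counts
    eps = sigma * torch.randn(u.shape, generator=generator, dtype=u.dtype)  # Gaussian part
    return gamma * z + eps

g = torch.Generator().manual_seed(0)
u = torch.full((200_000,), 0.5)  # noiseless measurements u = A(x), constant for this check
y = poisson_gaussian_noise(u, gamma=0.1, sigma=0.05, generator=g)
# Mean should be close to u = 0.5, variance close to gamma * u + sigma^2 = 0.0525.
print(y.mean().item(), y.var().item())
```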
The loss is computed as
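\[\begin{split}& \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y-\sigma^2 +\frac{2}{m\tau_1}\left(b\odot (\gamma y + \sigma^2 1)\right)^{\top} \left(A\inverse{y+\tau_1 b}-A\inverse{y} \right) \\ & -\frac{2\gamma \sigma^2}{m\tau_2^2}c^{\top} \left( A\inverse{y+\tau_2 c} + A\inverse{y-\tau_2 c} - 2A\inverse{y} \right)\end{split}\] where \(\inverse{y}\) is the reconstruction computed by the trainable network \(R\) from \(y\), \(y\) is the noisy measurement vector with \(m\) entries, \(1\) is the all-ones vector, \(b\) is a random vector with i.i.d. entries taking the values -1 and 1 each with probability 0.5, \(c\) is a random vector with i.i.d. zero-mean, unit-variance entries whose third moment equals 1, \(\tau_1\) and \(\tau_2\) are small positive constants, and \(\odot\) denotes elementwise multiplication. The last term approximates the second derivative and is only included if second_derivative is True.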
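A minimal PyTorch sketch of this formula, written as a standalone function rather than the deepinv forward method. The name pg_sure_loss is hypothetical, and model and A are plain callables standing in for the network \(R\) and the forward operator. \(b\) is Rademacher. For \(c\) the sketch assumes a skewed two-point distribution with mean 0, variance 1 and third moment 1; these moment conditions are what make \(c^{\top}(\cdot)/\tau_2^2\) estimate the sum of the diagonal second derivatives, whereas a symmetric choice such as Rademacher would give an estimate with zero mean. The second-derivative term is subtracted, which is the sign obtained by expanding the Poisson term to second order (it contributes \(-\gamma\sigma^2\) times the second derivative).

```python
import math
import torch

def pg_sure_loss(y, model, A, gamma, sigma, tau1=1e-3, tau2=1e-2,
                 second_derivative=False, generator=None):
    """Per-sample Poisson-Gaussian SURE estimate, shape (batch_size,).

    Hypothetical helper, not the deepinv implementation:
    y     -- noisy measurements, shape (batch_size, ...)
    model -- callable mapping measurements to a reconstruction R(y)
    A     -- callable forward operator x -> A(x)
    """
    batch = y.shape[0]
    sigma2 = sigma ** 2

    def flat_mean(t):
        # average over all non-batch entries: the 1/m factor in the formula
        return t.reshape(batch, -1).mean(dim=1)

    meas = A(model(y))  # A R(y)
    loss = flat_mean((y - meas) ** 2) - gamma * flat_mean(y) - sigma2

    # b: i.i.d. Rademacher entries (-1 or +1 with probability 0.5 each)
    b = torch.randint(0, 2, y.shape, generator=generator).to(y.dtype) * 2 - 1
    meas_b = A(model(y + tau1 * b))  # A R(y + tau1 b)
    loss = loss + 2 / tau1 * flat_mean(b * (gamma * y + sigma2) * (meas_b - meas))

    if second_derivative:
        # c: i.i.d. two-point entries with mean 0, variance 1 and third moment 1
        p = 0.5 + 0.5 / math.sqrt(5)
        draw = torch.rand(y.shape, generator=generator, dtype=y.dtype)
        c = torch.where(draw < p,
                        torch.full_like(y, -math.sqrt((1 - p) / p)),
                        torch.full_like(y, math.sqrt(p / (1 - p))))
        meas_p = A(model(y + tau2 * c))  # A R(y + tau2 c)
        meas_m = A(model(y - tau2 * c))  # A R(y - tau2 c)
        loss = loss - 2 * gamma * sigma2 / tau2 ** 2 * flat_mean(
            c * (meas_p + meas_m - 2 * meas))
    return loss

# Usage: identity forward operator and a linear "denoiser" R(y) = 0.8 y.
gen = torch.Generator().manual_seed(0)
gamma, sigma = 0.1, 0.05
u = torch.full((4, 1, 32, 32), 0.5)
y = gamma * torch.poisson(u / gamma, generator=gen) + sigma * torch.randn(u.shape, generator=gen)
loss = pg_sure_loss(y, model=lambda t: 0.8 * t, A=lambda x: x,
                    gamma=gamma, sigma=sigma, generator=gen)
print(loss.shape)  # torch.Size([4])
```

For a linear denoiser the finite differences are exact, so the output can be compared against the closed form \((1-\alpha)^2\overline{y^2} - \gamma\bar{y} - \sigma^2 + 2\alpha(\gamma\bar{y} + \sigma^2)\), where bars denote per-sample means.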
If the measurement data are truly Poisson-Gaussian, this loss is an unbiased estimator, up to the finite-difference and Taylor approximations, of the mean squared error \(\frac{1}{m}\|u-A\inverse{y}\|_2^2\), where \(u\) is the noiseless measurement.
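The unbiasedness claim can be checked numerically in a case where every approximation is exact: with \(A = I\) and the linear denoiser \(R(y) = \alpha y\), the divergence term equals \(\alpha \sum_i (\gamma y_i + \sigma^2)\) for any \(b\), and the second-derivative term vanishes. The sketch below (hypothetical helper sure_linear, arbitrary values of \(\alpha\), \(\gamma\), \(\sigma\)) compares the estimate, which only uses \(y\), with the oracle error, which requires the unknown \(u\).

```python
import torch

def sure_linear(y, alpha, gamma, sigma):
    """PG-SURE for R(y) = alpha * y and A = I (divergence exact, no second-derivative term)."""
    m = y.numel()
    div = alpha * (gamma * y + sigma ** 2).sum()
    return ((y - alpha * y) ** 2).mean() - gamma * y.mean() - sigma ** 2 + 2 * div / m

gen = torch.Generator().manual_seed(0)
gamma, sigma, alpha = 0.1, 0.05, 0.8
u = torch.rand(500_000, generator=gen, dtype=torch.float64)  # noiseless measurements in [0, 1)
y = gamma * torch.poisson(u / gamma, generator=gen) \
    + sigma * torch.randn(u.shape, generator=gen, dtype=torch.float64)

estimate = sure_linear(y, alpha, gamma, sigma).item()  # uses y only
oracle = ((u - alpha * y) ** 2).mean().item()          # needs the unknown u
print(estimate, oracle)  # close for large m
```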
See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.
Warning
The loss can be sensitive to the choice of \(\tau_1\) and \(\tau_2\), which should be proportional to the scale of the entries of \(y\). The default values \(\tau_1 = 10^{-3}\) and \(\tau_2 = 10^{-2}\) are adapted to \(y\) vectors with entries in \([0,1]\).
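To illustrate the sensitivity described in the warning, the sketch below compares the finite-difference Monte Carlo divergence \(b^{\top}(f(y+\tau b) - f(y))/\tau\) with the exact divergence for an illustrative smooth elementwise function \(f\) (an assumption standing in for \(A R\)) on data in \([0,1]\). The bias grows with \(\tau\) relative to the scale of \(y\); data rescaled to a larger range would call for a proportionally larger \(\tau\), while a \(\tau\) that is far too small eventually suffers from floating-point cancellation.

```python
import torch

def mc_divergence(f, y, tau, b):
    """Finite-difference Monte Carlo divergence estimate b^T (f(y + tau*b) - f(y)) / tau."""
    return (b * (f(y + tau * b) - f(y))).sum() / tau

def f(t):
    # smooth elementwise "denoiser" (illustrative choice, not a trained network)
    return torch.tanh(5 * t) / 5

gen = torch.Generator().manual_seed(0)
y = torch.rand(10_000, generator=gen, dtype=torch.float64)         # entries in [0, 1)
b = torch.randint(0, 2, y.shape, generator=gen).to(y.dtype) * 2 - 1  # Rademacher vector
exact = (1 - torch.tanh(5 * y) ** 2).sum()                         # sum of f'(y_i)

for tau in (1e-3, 1e-2, 1e-1, 1.0):
    rel_err = abs(mc_divergence(f, y, tau, b) - exact).item() / exact.item()
    print(f"tau={tau:g}  relative error={rel_err:.2e}")
```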
Note
If the noise levels are unknown, set unsure to True to use the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, which also learns the noise levels.
- Parameters:
sigma (float) – Standard deviation of the Gaussian noise.
gain (float) – Gain \(\gamma\) of the Poisson noise.
tau1 (float) – Finite-difference step \(\tau_1\) used in the Monte Carlo approximation of the divergence.
tau2 (float) – Finite-difference step \(\tau_2\) used in the Monte Carlo approximation of the second derivative.
second_derivative (bool) – If False, the last term of the loss (approximating the second derivative) is removed to speed up computation, at the cost of a possibly inexact loss. Default False.
unsure (bool) – If True, the loss is adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, where \(\gamma\) and \(\sigma^2\) are also learned (their input values are used as initialization).
step_size (tuple[float]) – Step sizes for the gradient ascent of the noise levels if unsure is True.
momentum (tuple[float]) – Momentum for the gradient ascent of the noise levels if unsure is True.
rng (torch.Generator) – Optional random number generator. Default is None.
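The exact update applied to the noise levels when unsure is True is not specified here. Assuming a standard heavy-ball (momentum) gradient-ascent step, the roles of step_size and momentum can be sketched on a single scalar parameter; heavy_ball_ascent is a hypothetical name, and the toy concave objective is a stand-in, not the UNSURE objective.

```python
def heavy_ball_ascent(grad, theta0, step_size, momentum, n_steps):
    """Gradient ascent with heavy-ball momentum on a scalar parameter (assumed update form)."""
    theta, velocity = theta0, 0.0
    for _ in range(n_steps):
        velocity = momentum * velocity + grad(theta)  # accumulate past ascent directions
        theta = theta + step_size * velocity          # move uphill
    return theta

# Toy concave objective -(theta - 2)^2 with maximizer 2; its gradient is -2 * (theta - 2).
theta = heavy_ball_ascent(lambda t: -2.0 * (t - 2.0), theta0=0.5,
                          step_size=0.1, momentum=0.9, n_steps=500)
print(theta)  # converges toward 2
```

Presumably each tuple holds one value per learned noise level (\(\gamma\) and \(\sigma^2\)). A larger momentum speeds progress along consistent ascent directions but can cause oscillations if the step size is too large.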
- forward(y, x_net, physics, model, **kwargs)
Computes the SURE loss.
- Parameters:
y (torch.Tensor) – measurements.
x_net (torch.Tensor) – reconstructed image \(\inverse{y}\).
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction network.
- Returns:
torch.Tensor – loss of shape (batch_size,).
Examples using SurePGLoss:
Self-supervised denoising with the UNSURE loss.
Self-supervised denoising with the SURE loss.