SurePGLoss

class deepinv.loss.SurePGLoss(sigma, gain, tau1=0.001, tau2=0.01, second_derivative=False, unsure=False, step_size=(0.0001, 0.0001), momentum=(0.9, 0.9), rng=None)

Bases: Loss

SURE loss for Poisson-Gaussian noise

The loss is designed for the following noise model:

\[y = \gamma z + \epsilon\]

where \(u = A(x)\), \(z \sim \mathcal{P}\left(\frac{u}{\gamma}\right)\), and \(\epsilon \sim \mathcal{N}(0, \sigma^2 I)\).
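The noise model above can be sampled directly. The sketch below is a minimal illustration, assuming \(u\) is a nonnegative PyTorch tensor; the helper name poisson_gaussian is hypothetical and is not part of the deepinv API.

```python
import torch


def poisson_gaussian(u, gamma, sigma, generator=None):
    """Hypothetical helper: sample y = gamma * z + eps.

    z ~ Poisson(u / gamma) entrywise and eps ~ N(0, sigma^2 I); u must be >= 0.
    """
    z = torch.poisson(u / gamma, generator=generator)  # Poisson counts
    eps = sigma * torch.randn(u.shape, generator=generator, dtype=u.dtype, device=u.device)
    return gamma * z + eps  # scaled counts plus Gaussian noise


g = torch.Generator().manual_seed(0)
u = torch.full((1_000_000,), 0.5)  # noiseless measurement u = A(x)
y = poisson_gaussian(u, gamma=0.1, sigma=0.05, generator=g)
```

Entrywise, this gives \(\mathbb{E}[y] = u\) and \(\mathrm{Var}[y] = \gamma u + \sigma^2\): the Poisson part contributes a signal-dependent variance, the Gaussian part a constant one.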

The loss is computed as

\[\begin{split}& \frac{1}{m}\|y-A(R(y))\|_2^2-\frac{\gamma}{m} 1^{\top}y-\sigma^2 +\frac{2}{m\tau_1}\left(b\odot (\gamma y + \sigma^2 1)\right)^{\top} \left(A(R(y+\tau_1 b))-A(R(y)) \right) \\ & -\frac{2\gamma \sigma^2}{m\tau_2^2}c^{\top} \left( A(R(y+\tau_2 c)) + A(R(y-\tau_2 c)) - 2A(R(y)) \right)\end{split}\]

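where \(R\) is the trainable network, \(y\) is the noisy measurement vector with \(m\) entries, \(1\) is the all-ones vector, \(b\) is a random vector whose entries are independently \(-1\) or \(1\) with probability 0.5 each, \(c\) is a second random vector used to approximate the second derivative, \(\tau_1\) and \(\tau_2\) are small positive finite-difference steps, and \(\odot\) denotes elementwise multiplication.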
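To make the terms concrete, here is a minimal self-contained PyTorch sketch of the loss for a batch of measurements. It is not the deepinv implementation: R and A are plain callables standing in for the network and the forward operator (hypothetical stand-ins, not the forward(y, x_net, physics, model) signature), and \(c\) is assumed to have i.i.d. two-point entries with zero mean, unit variance and unit third moment. The second-derivative correction is subtracted, which is the sign obtained by a second-order Taylor expansion of the Poisson identity \(\mathbb{E}[z g(z)] = \lambda \mathbb{E}[g(z+1)]\) combined with Stein's lemma for the Gaussian part.

```python
import math
import torch


def pg_sure_loss(y, R, A, sigma, gamma, tau1=1e-3, tau2=1e-2,
                 second_derivative=False, generator=None):
    """Illustrative Poisson-Gaussian SURE, one value per batch element.

    y: noisy measurements of shape (B, ...); R: callable y -> x (network);
    A: callable x -> y (forward operator). Hypothetical stand-ins only.
    """
    batch = y.shape[0]
    m = y[0].numel()  # number of entries of one measurement

    def per_sample_sum(t):
        # Sum over every dimension except the batch dimension -> shape (B,)
        return t.reshape(batch, -1).sum(dim=1)

    f0 = A(R(y))

    # Data-fidelity term minus the offset (gamma/m) 1^T y + sigma^2
    loss = per_sample_sum((y - f0) ** 2) / m
    loss = loss - gamma * per_sample_sum(y) / m - sigma ** 2

    # b: i.i.d. entries equal to -1 or +1 with probability 0.5 each
    b = 2.0 * torch.randint(0, 2, y.shape, generator=generator, device=y.device).to(y.dtype) - 1.0
    f1 = A(R(y + tau1 * b))
    # Finite-difference Monte Carlo estimate of the weighted divergence
    loss = loss + 2.0 / (m * tau1) * per_sample_sum(b * (gamma * y + sigma ** 2) * (f1 - f0))

    if second_derivative:
        # c: assumed two-point entries with zero mean, unit variance and unit
        # third moment (a symmetric c would cancel this term in expectation)
        p = 0.5 + 0.5 / math.sqrt(5.0)
        unif = torch.rand(y.shape, generator=generator, device=y.device, dtype=y.dtype)
        c = torch.where(unif < p,
                        torch.full_like(y, -math.sqrt((1.0 - p) / p)),
                        torch.full_like(y, math.sqrt(p / (1.0 - p))))
        f2p = A(R(y + tau2 * c))
        f2n = A(R(y - tau2 * c))
        # Second-derivative correction, subtracted
        loss = loss - 2.0 * gamma * sigma ** 2 / (m * tau2 ** 2) * per_sample_sum(
            c * (f2p + f2n - 2.0 * f0))
    return loss
```

For a linear denoiser such as R(y) = 0.5 y with A the identity, the finite differences are exact and the second-derivative term vanishes, so averaging over many entries should bring the estimate close to the true mean squared error against \(u\).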

If the measurements truly follow the Poisson-Gaussian model, this loss is, up to the Taylor and finite-difference approximations of the derivatives, an unbiased estimator of the mean squared error \(\frac{1}{m}\|u-A(R(y))\|_2^2\), where \(u = A(x)\) is the noiseless measurement.
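Unbiasedness of the two Monte Carlo terms rests on the moments of \(b\) and \(c\). For the divergence term, \(\mathbb{E}[b_i b_j] = \delta_{ij}\) keeps only the diagonal of the Jacobian. For the second-derivative term, a Taylor expansion gives \(c^\top(f(y+\tau c)+f(y-\tau c)-2f(y)) \approx \tau^2 \sum_{i,j,k} c_i c_j c_k\, \partial^2 f_i/\partial y_j \partial y_k\), so with i.i.d. zero-mean entries only \(i=j=k\) survives, weighted by \(\mathbb{E}[c^3]\), which must equal 1. A Rademacher vector has zero third moment, so reusing \(b\) for \(c\) would cancel the term. The sketch below checks these moments exactly; the two-point distribution with \(p = (1 + 1/\sqrt{5})/2\) is one choice satisfying the conditions, assumed here for illustration.

```python
import math


def rademacher_moments():
    # Entries -1 and +1, each with probability 0.5: returns E[b], E[b^2], E[b^3]
    return tuple(0.5 * (-1) ** k + 0.5 * 1 ** k for k in (1, 2, 3))


def two_point_moments(p=0.5 + 0.5 / math.sqrt(5.0)):
    # Assumed distribution for c: lo with probability p, hi with probability 1 - p
    lo = -math.sqrt((1.0 - p) / p)
    hi = math.sqrt(p / (1.0 - p))
    return tuple(p * lo ** k + (1.0 - p) * hi ** k for k in (1, 2, 3))


b_moments = rademacher_moments()  # zero mean, unit variance, zero third moment
c_moments = two_point_moments()   # zero mean, unit variance, unit third moment
```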

See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.

Warning

The loss can be sensitive to the choice of \(\tau_1\) and \(\tau_2\), which should be proportional to the scale of the entries of \(y\). The default values are adapted to \(y\) vectors with entries in \([0,1]\).

Note

If the noise levels are unknown, set unsure=True to use the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, which also learns the noise levels.

Parameters:
  • sigma (float) – Standard deviation of the Gaussian noise.

  • gain (float) – Gain \(\gamma\) of the Poisson noise.

  • tau1 (float) – Finite-difference step \(\tau_1\) for the Monte Carlo approximation of the divergence. Default 0.001.

  • tau2 (float) – Finite-difference step \(\tau_2\) for the Monte Carlo approximation of the second derivative. Default 0.01.

  • second_derivative (bool) – If False, the last term in the loss (approximating the second derivative) is dropped, which saves two network evaluations at the cost of a possibly inexact loss. Default False.

  • unsure (bool) – If True, the loss is adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, where \(\gamma\) and \(\sigma^2\) are also learned (their input values are used as initialization).

  • step_size (tuple[float]) – Step sizes for the gradient ascent on the noise levels when unsure is True.

  • momentum (tuple[float]) – Momentum coefficients for the gradient ascent on the noise levels when unsure is True.

  • rng (torch.Generator) – Optional random number generator. Default is None.
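With unsure=True, the noise levels are updated by gradient ascent with momentum while the network keeps minimizing the loss. The sketch below shows one way to set up such an ascent in plain PyTorch, assuming a heavy-ball update with one step size and one momentum per noise level; it uses torch.optim.SGD with maximize=True. The helper name make_noise_level_optimizer and the assignment of the first tuple entry to \(\gamma\) and the second to \(\sigma^2\) are assumptions, and deepinv's internal update may differ.

```python
import torch


def make_noise_level_optimizer(gamma, sigma2, step_size=(1e-4, 1e-4), momentum=(0.9, 0.9)):
    """Hypothetical helper: heavy-ball gradient *ascent* on (gamma, sigma2).

    gamma and sigma2 are scalar tensors with requires_grad=True. Assumes the
    first tuple entry applies to gamma and the second to sigma2.
    """
    groups = [
        {"params": [gamma], "lr": step_size[0], "momentum": momentum[0]},
        {"params": [sigma2], "lr": step_size[1], "momentum": momentum[1]},
    ]
    # maximize=True makes SGD step along +gradient, i.e. gradient ascent
    return torch.optim.SGD(groups, lr=step_size[0], maximize=True)


# Toy check on a concave objective whose maximum is at gamma = 2, sigma2 = 0.5
gamma = torch.tensor(0.0, requires_grad=True)
sigma2 = torch.tensor(0.0, requires_grad=True)
opt = make_noise_level_optimizer(gamma, sigma2, step_size=(0.05, 0.05))
for _ in range(500):
    opt.zero_grad()
    objective = -(gamma - 2.0) ** 2 - (sigma2 - 0.5) ** 2
    objective.backward()
    opt.step()
```

In an UNSURE-style setup, the objective would instead be the loss value computed with \(\gamma\) and \(\sigma^2\) as these tensors, while a separate optimizer descends on the network parameters.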

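forward(y, x_net, physics, model, **kwargs)

Computes the SURE loss.

Parameters:
  • y (torch.Tensor) – Noisy measurements.

  • x_net (torch.Tensor) – Reconstruction \(R(y)\) returned by the network.

  • physics (deepinv.physics.Physics) – Forward operator associated with the measurements.

  • model (torch.nn.Module) – Reconstruction network \(R\).

Returns:

torch.Tensor – loss of size (batch_size,)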

Examples using SurePGLoss:

Self-supervised denoising with the UNSURE loss.

Self-supervised denoising with the SURE loss.