SureGaussianLoss
- class deepinv.loss.SureGaussianLoss(sigma, tau=0.01, B=<function SureGaussianLoss.<lambda>>, unsure=False, step_size=0.0001, momentum=0.9, rng: torch.Generator | None = None)
Bases: Loss
SURE loss for Gaussian noise
The loss is designed for the following noise model:
\[y \sim\mathcal{N}(u,\sigma^2 I) \quad \text{with}\quad u= A(x).\]
The loss is computed as
\[\frac{1}{m}\|B(y - A R(y))\|_2^2 -\sigma^2 +\frac{2\sigma^2}{m\tau}b^{\top} B^{\top} \left(A R(y+\tau b) - A R(y)\right)\]
where \(R\) is the trainable network, \(A\) is the forward operator, \(y\) is the noisy measurement vector of size \(m\), \(B\) is an optional linear mapping which should be approximately \(A^{\dagger}\) (or any stable approximation of it), \(b\sim\mathcal{N}(0,I)\) and \(\tau > 0\) is a hyperparameter controlling the Monte Carlo approximation of the divergence.
The divergence of \(A R(y)\) that appears in the original SURE loss is approximated here with the Monte Carlo estimator proposed in https://ieeexplore.ieee.org/abstract/document/4099398/
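As a brief reminder of why such an estimator works, the identity below is a sketch in the notation of this page. The symbols \(f\) (any differentiable map from \(\mathbb{R}^m\) to \(\mathbb{R}^m\)) and \(J_f\) (its Jacobian) are introduced here for illustration only:
\[\operatorname{div} f(y) = \operatorname{tr} J_f(y) = \mathbb{E}_b\left[b^{\top} J_f(y)\, b\right] \approx \frac{1}{\tau}\, b^{\top}\left(f(y+\tau b) - f(y)\right), \quad b\sim\mathcal{N}(0,I).\]
The final step replaces the expectation by a single random draw of \(b\) and the Jacobian-vector product by a finite difference of step \(\tau\), which is why the result depends on the choice of \(\tau\).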
If the measurement noise is truly Gaussian with standard deviation \(\sigma\), this loss is an unbiased estimator of the mean squared error \(\frac{1}{m}\|u-A R(y)\|_2^2\), where \(u\) is the noiseless measurement.
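The formula can be illustrated with a minimal PyTorch sketch for the simplest setting, assuming a denoising problem where both \(A\) and \(B\) are the identity. This is not the library's implementation: the function name `mc_sure_loss` and the toy linear "denoiser" `shrink` are hypothetical and only serve to show how the three terms are assembled.

```python
import torch


def mc_sure_loss(y, model, sigma, tau=0.01, generator=None):
    """Monte Carlo SURE for denoising, assuming A = I and B = I.

    Returns one loss value per batch element, shape (batch_size,).
    """
    m = y[0].numel()  # number of measurements per sample
    # Random direction b ~ N(0, I) used to probe the divergence.
    b = torch.randn(y.shape, generator=generator, dtype=y.dtype, device=y.device)
    out = model(y)
    out_perturbed = model(y + tau * b)
    # Data-fidelity term (1/m) * ||y - R(y)||^2.
    residual = (y - out).flatten(1).pow(2).sum(dim=1) / m
    # Finite-difference divergence term (1/(m*tau)) * b^T (R(y + tau*b) - R(y)).
    div = (b * (out_perturbed - out)).flatten(1).sum(dim=1) / (m * tau)
    return residual - sigma**2 + 2 * sigma**2 * div


# Toy check with a hypothetical linear "denoiser" R(y) = 0.5 * y.
gen = torch.Generator().manual_seed(0)
sigma = 0.1
x = torch.rand(4, 1, 64, 64, generator=gen)          # clean images in [0, 1]
y = x + sigma * torch.randn(x.shape, generator=gen)  # noisy measurements
shrink = lambda t: 0.5 * t

sure = mc_sure_loss(y, shrink, sigma, generator=gen)  # uses y only
mse = (x - shrink(y)).flatten(1).pow(2).mean(dim=1)   # needs the ground truth x
```

In this toy setting, `sure` is computed from the noisy data alone, while `mse` requires the clean signal. On average the two should be close, which is the property that makes the loss usable for self-supervised training.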
Warning
The loss can be sensitive to the choice of \(\tau\), which should scale with the magnitude of the entries of \(y\). The default value of 0.01 is adapted to \(y\) vectors with entries in \([0,1]\).
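One possible way to follow this advice is sketched below. The helper `scaled_tau` is a hypothetical heuristic and not part of the library: it assumes that multiplying the default value by the largest absolute entry of \(y\) is an adequate notion of scale for the data at hand.

```python
import torch


def scaled_tau(y, base_tau=0.01):
    """Hypothetical heuristic: scale the default tau by the largest |entry| of y."""
    return base_tau * y.abs().max().item()


y_unit = torch.tensor([0.0, 0.5, 1.0])      # entries in [0, 1]
y_8bit = torch.tensor([0.0, 127.5, 255.0])  # entries in [0, 255]

tau_unit = scaled_tau(y_unit)  # stays at the default scale
tau_8bit = scaled_tau(y_8bit)  # grows with the data range
```

The resulting value could then be passed as the `tau` argument when constructing the loss.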
Note
If the noise level is unknown, the loss can be adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, which also learns the noise level.
- Parameters:
sigma (float) – Standard deviation of the Gaussian noise.
tau (float) – Approximation constant for the Monte Carlo approximation of the divergence.
B (Callable, str) – Optional linear metric \(B\), which can be used to improve the performance of the loss. If 'A_dagger', the pseudo-inverse of the forward operator is used. Otherwise the metric should be a linear operator that approximates the pseudo-inverse of the forward operator, such as deepinv.physics.LinearPhysics.prox_l2() with large \(\gamma\). By default, the identity is used.
unsure (bool) – If True, the loss is adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985, where the noise level \(\sigma\) is also learned (the input value is used as initialization).
step_size (float) – Step size for the gradient ascent of the noise level if unsure is True.
momentum (float) – Momentum for the gradient ascent of the noise level if unsure is True.
rng (torch.Generator) – Optional random number generator. Default is None.
- forward(y, x_net, physics, model, **kwargs)
Computes the SURE Loss.
- Parameters:
y (torch.Tensor) – Measurements.
x_net (torch.Tensor) – Reconstructed image \(R(y)\).
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction network.
- Returns:
(torch.Tensor) Loss of size (batch_size,).
Examples using SureGaussianLoss:
Self-supervised denoising with the UNSURE loss.
Self-supervised denoising with the SURE loss.