ScoreLoss

class deepinv.loss.ScoreLoss(noise_model=None, total_batches=1000, delta=(0.001, 0.1))

Bases: Loss

Learns the score of the noisy measurement distribution.

Approximates the score of the measurement distribution, \(S(y)\approx \nabla \log p(y)\), following Noise2Score (https://proceedings.neurips.cc/paper_files/paper/2021/file/077b83af57538aa183971a2fe0971ec1-Paper.pdf).

The score loss is defined as

\[\| \epsilon + \sigma S(y+ \sigma \epsilon) \|^2\]

where \(y\) is the noisy measurement, \(S\) is the model approximating the score of the noisy measurement distribution \(\nabla \log p(y)\), \(\epsilon\) is sampled from \(N(0,I)\) and \(\sigma\) is sampled from \(N(0,I\delta^2)\) with \(\delta\) annealed during training from a maximum value to a minimum value.
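The loss above can be sketched in NumPy as follows. This is a minimal illustration, not the deepinv implementation. Several choices are assumptions made for the example: the linear annealing schedule, drawing one scalar \(\sigma\) per sample, averaging over the batch, and the hypothetical helper names `annealed_delta` and `score_loss`. `score_fn` stands in for the backbone \(S\).

```python
import numpy as np


def annealed_delta(step, total_batches, delta=(0.001, 0.1)):
    """Hypothetical linear schedule: returns delta[1] (max) at step 0 and
    delta[0] (min) once step reaches total_batches."""
    w = min(step / total_batches, 1.0)
    return delta[1] * (1.0 - w) + delta[0] * w


def score_loss(score_fn, y, delta, rng):
    """Monte Carlo estimate of ||eps + sigma * S(y + sigma * eps)||^2,
    summed over each sample and averaged over the batch (first axis)."""
    batch = y.shape[0]
    # Assumption: one scalar sigma ~ N(0, delta^2) per sample, broadcast over the other axes.
    sigma = rng.normal(0.0, delta, size=(batch,) + (1,) * (y.ndim - 1))
    eps = rng.standard_normal(y.shape)
    residual = eps + sigma * score_fn(y + sigma * eps)
    return float(np.mean(np.sum(residual.reshape(batch, -1) ** 2, axis=1)))


rng = np.random.default_rng(0)
y = rng.standard_normal((4, 3, 5, 5))  # stand-in batch of noisy measurements
for step in (0, 500, 1000):
    d = annealed_delta(step, total_batches=1000)
    # Toy backbone: the score of a standard normal density, S(y) = -y.
    print(step, d, score_loss(lambda z: -z, y, d, rng))
```

Because \(\sigma\) multiplies \(S\), a small \(\delta\) makes the objective focus on the score of the unperturbed measurement density, which is why \(\delta\) is driven toward its minimum as training progresses.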

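At test/evaluation time, the method plugs the learned score into Tweedie’s formula to compute the denoised estimate \(R(y)\). The formula depends on the noise model, where \(\sigma\) is the standard deviation of the Gaussian measurement noise (not the \(\sigma\) sampled during training), \(\gamma\) is the Poisson gain and \(\ell\) is the parameter of the Gamma noise model: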

  • Gaussian noise: \(R(y) = y + \sigma^2 S(y)\)

  • Poisson noise: \(R(y) = y + \gamma y S(y)\)

  • Gamma noise: \(R(y) = \frac{\ell y}{(\ell-1)-y S(y)}\)
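The NumPy sketch below applies these formulas to a given score estimate. The function name `tweedie_estimate` and its keyword arguments are hypothetical: `sigma` is the Gaussian noise standard deviation, `gain` the Poisson gain \(\gamma\) and `l` the Gamma parameter \(\ell\). The Gaussian case is checked against a toy Gaussian prior, for which both the exact score and the posterior mean are known in closed form.

```python
import numpy as np


def tweedie_estimate(y, score, noise, sigma=None, gain=None, l=None):
    """Apply the inference formulas above to a score estimate S(y).

    noise: "gaussian" (uses sigma), "poisson" (uses gain) or "gamma" (uses l).
    """
    if noise == "gaussian":
        return y + sigma ** 2 * score
    if noise == "poisson":
        return y + gain * y * score
    if noise == "gamma":
        return l * y / ((l - 1.0) - y * score)
    raise ValueError(f"unsupported noise model: {noise}")


# Sanity check in the Gaussian case: if x ~ N(0, a^2) and y = x + sigma * n,
# then p(y) = N(0, a^2 + sigma^2), whose exact score is -y / (a^2 + sigma^2),
# and Tweedie's formula should return the posterior mean a^2 / (a^2 + sigma^2) * y.
a, sigma = 2.0, 0.5
y = np.linspace(-3.0, 3.0, 7)
exact_score = -y / (a ** 2 + sigma ** 2)
x_hat = tweedie_estimate(y, exact_score, "gaussian", sigma=sigma)
print(np.allclose(x_hat, a ** 2 / (a ** 2 + sigma ** 2) * y))  # True
```

With a zero score, the Gaussian and Poisson formulas return \(y\) unchanged, while the Gamma formula returns \(\frac{\ell}{\ell-1} y\).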

Warning

The user must pass a backbone model \(S\) to adapt_model(), which returns the full reconstruction network \(R\). This step is mandatory for the loss to be computed properly.

Warning

This class uses an inference formula for Poisson noise that differs from the one proposed in Noise2Score.

Note

This class does not support general inverse problems; it is designed only for denoising problems.

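Parameters:

noise_model (None, torch.nn.Module) – Noise model of the measurements (Gaussian, Poisson or Gamma, see the formulas above). If None, the noise model of the physics operator passed at forward time is used.

total_batches (int) – Total number of training batches over which \(\delta\) is annealed.

delta (tuple) – Minimum and maximum values of \(\delta\), annealed during training from the maximum to the minimum.
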

Example:
>>> import torch
>>> import deepinv as dinv
>>> sigma = 0.1
>>> physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
>>> model = dinv.models.DnCNN(depth=2, pretrained=None)
>>> loss = dinv.loss.ScoreLoss(total_batches=1, delta=(0.001, 0.1))
>>> model = loss.adapt_model(model) # important step!
>>> x = torch.ones((1, 3, 5, 5))
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save score loss in forward
>>> l = loss(model)
>>> print(l.item() > 0)
True

adapt_model(model, **kwargs)

Transforms the score backbone network \(S\) into the reconstruction network \(R\) used for training and evaluation.

Parameters:

model (torch.nn.Module) – Backbone model approximating the score.

Returns:

(torch.nn.Module) Adapted reconstruction model.
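To illustrate what the adaptation step does, the following toy sketch wraps a score function \(S\) into a reconstruction callable \(R\) for Gaussian noise. It only mimics the calling pattern of the example above: a forward pass with `update_parameters=True` stores the score loss, and the loss function then reads it back. The class `ToyScoreModel`, the function `toy_score_loss`, their attributes and the fixed `delta` are hypothetical and do not reproduce deepinv's internals.

```python
import numpy as np


class ToyScoreModel:
    """Hypothetical S -> R adapter for Gaussian noise with known std `noise_sigma`."""

    def __init__(self, score_fn, noise_sigma, delta=0.1, seed=0):
        self.score_fn = score_fn      # backbone S
        self.noise_sigma = noise_sigma
        self.delta = delta            # fixed here; annealed during real training
        self.rng = np.random.default_rng(seed)
        self.error = None             # last stored score-loss value

    def __call__(self, y, update_parameters=False):
        if update_parameters:
            # Store ||eps + s * S(y + s * eps)||^2 averaged over the batch, s ~ N(0, delta^2).
            batch = y.shape[0]
            s = self.rng.normal(0.0, self.delta, size=(batch,) + (1,) * (y.ndim - 1))
            eps = self.rng.standard_normal(y.shape)
            res = eps + s * self.score_fn(y + s * eps)
            self.error = float(np.mean(np.sum(res.reshape(batch, -1) ** 2, axis=1)))
        # Gaussian Tweedie step: R(y) = y + sigma^2 * S(y).
        return y + self.noise_sigma ** 2 * self.score_fn(y)


def toy_score_loss(model):
    """Return the loss stored by the last forward pass with update_parameters=True."""
    return model.error


model = ToyScoreModel(score_fn=lambda z: -z, noise_sigma=0.1)
y = np.ones((1, 3, 5, 5))
x_net = model(y, update_parameters=True)
print(x_net.shape, toy_score_loss(model) > 0)
```

Computing the loss inside the forward pass lets it reuse the same noisy batch the model already sees, so the loss object only needs the model and not the measurements.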

forward(model, **kwargs)

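Computes the score loss.

Parameters:

model (torch.nn.Module) – Reconstruction model returned by adapt_model(), on which a forward pass with update_parameters=True has been run.

Returns: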

(torch.Tensor) Score loss.