ScoreLoss
- class deepinv.loss.ScoreLoss(noise_model=None, total_batches=1000, delta=(0.001, 0.1))
Bases: Loss
Learns the score of the noisy measurement distribution.
Approximates the score of the measurement distribution with a network, \(S(y)\approx \nabla \log p(y)\); see https://proceedings.neurips.cc/paper_files/paper/2021/file/077b83af57538aa183971a2fe0971ec1-Paper.pdf.
The score loss is defined as
\[\| \epsilon + \sigma S(y+ \sigma \epsilon) \|^2\]
where \(y\) is the noisy measurement, \(S\) is the model approximating the score of the noisy measurement distribution \(\nabla \log p(y)\), \(\epsilon\) is sampled from \(N(0,I)\) and \(\sigma\) is sampled from \(N(0,I\delta^2)\), with \(\delta\) annealed during training from a maximum value to a minimum value.
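The loss above can be sketched in plain Python for a single measurement. This is an illustrative sketch, not the library's implementation: the helper name `score_loss`, the use of lists instead of tensors, and the choice of one scalar \(\sigma\) per measurement are all assumptions made here for readability.

```python
import random


def score_loss(score_fn, y, delta, rng):
    """One Monte Carlo sample of || eps + sigma * S(y + sigma * eps) ||^2.

    score_fn is a hypothetical stand-in for the backbone S; it maps a list
    of floats to a list of floats of the same length.
    """
    # Assumption: a single scalar sigma ~ N(0, delta^2) per measurement.
    sigma = rng.gauss(0.0, delta)
    # eps ~ N(0, I), one entry per element of y.
    eps = [rng.gauss(0.0, 1.0) for _ in y]
    # Perturb the measurement and evaluate the score model there.
    y_perturbed = [yi + sigma * ei for yi, ei in zip(y, eps)]
    s = score_fn(y_perturbed)
    # Squared Euclidean norm of the residual eps + sigma * S(.).
    return sum((ei + sigma * si) ** 2 for ei, si in zip(eps, s))


# Usage: a toy score for zero-mean, unit-variance Gaussian data, S(y) = -y.
rng = random.Random(0)
y = [0.3, -1.2, 0.8]
loss_value = score_loss(lambda v: [-vi for vi in v], y, delta=0.1, rng=rng)
```

In a real training loop the same quantity would be averaged over a batch and backpropagated through the score network.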
At test/evaluation time, the method uses Tweedie’s formula to estimate the clean image \(R(y)\) from the learned score \(S(y)\); the formula depends on the noise model used:
Gaussian noise: \(R(y) = y + \sigma^2 S(y)\)
Poisson noise: \(R(y) = y + \gamma y S(y)\)
Gamma noise: \(R(y) = \frac{\ell y}{(\ell-1)-y S(y)}\)
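The three formulas above translate directly into code. The sketch below uses plain Python scalars and hypothetical function names; \(\sigma\), \(\gamma\) and \(\ell\) are assumed here to be the parameters of the corresponding noise models, which this page does not spell out. The Gaussian case is checked against a setting where the posterior mean is known in closed form.

```python
def tweedie_gaussian(y, score, sigma):
    """Gaussian noise: R(y) = y + sigma^2 * S(y)."""
    return y + sigma ** 2 * score


def tweedie_poisson(y, score, gamma):
    """Poisson noise: R(y) = y + gamma * y * S(y)."""
    return y + gamma * y * score


def tweedie_gamma(y, score, ell):
    """Gamma noise: R(y) = ell * y / ((ell - 1) - y * S(y))."""
    return ell * y / ((ell - 1.0) - y * score)


# Sanity check (Gaussian case): if x ~ N(0, tau^2) and y = x + sigma * n,
# then p(y) = N(0, tau^2 + sigma^2), so the exact score is
# -y / (tau^2 + sigma^2) and the posterior mean is
# y * tau^2 / (tau^2 + sigma^2).
tau, sigma, y = 1.0, 0.5, 2.0
exact_score = -y / (tau ** 2 + sigma ** 2)
estimate = tweedie_gaussian(y, exact_score, sigma)
posterior_mean = y * tau ** 2 / (tau ** 2 + sigma ** 2)
```

With the exact score, `estimate` and `posterior_mean` agree, which is the property that makes a learned score usable as a denoiser.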
Warning
The user should provide a backbone model \(S\) to adapt_model(), which returns the full reconstruction network \(R\). This step is mandatory to compute the loss properly.
Warning
For the Poisson noise case, this class uses an inference formula that differs from the one proposed in Noise2Score.
Note
This class does not support general inverse problems; it is designed only for denoising problems.
- Parameters:
noise_model (None, torch.nn.Module) – Noise distribution corrupting the measurements (see the physics docs). Options are deepinv.physics.GaussianNoise, deepinv.physics.PoissonNoise, deepinv.physics.GammaNoise and deepinv.physics.UniformGaussianNoise. By default, it uses the noise model associated with the physics operator provided in the forward method.
total_batches (int) – Total number of training batches (epochs * number of batches per epoch).
delta (tuple) – Tuple of two floats representing the minimum and maximum noise level, which are annealed during training.
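To illustrate how total_batches and delta interact, the sketch below anneals \(\delta\) from its maximum to its minimum over the course of training. The linear interpolation and the function name `annealed_delta` are assumptions made for illustration; this page does not state which schedule the class actually uses.

```python
def annealed_delta(step, total_batches, delta=(0.001, 0.1)):
    """Hypothetical linear annealing from delta max down to delta min."""
    delta_min, delta_max = delta
    # Fraction of training completed, clamped to [0, 1].
    w = min(max(step / total_batches, 0.0), 1.0)
    # Start at delta_max (w = 0) and end at delta_min (w = 1).
    return delta_max * (1.0 - w) + delta_min * w


# Usage: noise level at the start, middle and end of 1000 training batches.
schedule = [annealed_delta(s, 1000) for s in (0, 500, 1000)]
```

Under this assumed schedule, setting total_batches smaller than the real number of training batches would leave \(\delta\) at its minimum for the remainder of training.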
- Example:
>>> import torch
>>> import deepinv as dinv
>>> sigma = 0.1
>>> physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
>>> model = dinv.models.DnCNN(depth=2, pretrained=None)
>>> loss = dinv.loss.ScoreLoss(total_batches=1, delta=(0.001, 0.1))
>>> model = loss.adapt_model(model)  # important step!
>>> x = torch.ones((1, 3, 5, 5))
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True)  # save score loss in forward
>>> l = loss(model)
>>> print(l.item() > 0)
True
- adapt_model(model, **kwargs)
Transforms score backbone net S() into R() for training and evaluation.
- Parameters:
model (torch.nn.Module) – Backbone model approximating the score.
- Returns:
(torch.nn.Module) Adapted reconstruction model.
- forward(model, **kwargs)
Computes the Score Loss.
- Parameters:
y (torch.Tensor) – Measurements.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction model.
- Returns:
(torch.Tensor) Score loss.