EquivariantSplittingLoss#

class deepinv.loss.EquivariantSplittingLoss(*, mask_generator=None, consistency_loss=None, prediction_loss=None, eval_n_samples=5, transform=None, eval_transform=None, img_size=None)[source]#

Bases: Loss

Equivariant splitting loss.

Implements the measurement splitting loss proposed by Sechaud et al.[1]. It generalizes the regular deepinv.loss.SplittingLoss by adding a measurement consistency term, which supports both noise-less losses such as deepinv.loss.MCLoss and noise-aware losses such as deepinv.loss.R2RLoss and deepinv.loss.SureGaussianLoss. Moreover, it automatically renders the base reconstructor equivariant using the Reynolds averaging implemented in deepinv.models.EquivariantReconstructor.

The training loss takes the general form:

\[\mathcal{L}_{\mathrm{ES}} (y, A, f) = \mathbb{E}_g \Big\{ \mathbb{E}_{y_1, A_1 \mid y, A T_g} \Big\{ \underbrace{\| A_1 R(y_1, A_1) - A_1 x \|^2}_{\text{Consistency term}} + \underbrace{\| A_2 R(y_1, A_1) - A_2 x \|^2}_{\text{Prediction term}} \Big\} \Big\}\]

where \(R\) denotes the reconstructor, \(A\) the physics operator, \(x\) the ground truth image, \(y\) the measurement, and \(T_g\) a group action (e.g., a rotation).

The second expectation is taken over the distribution, specified by mask_generator, of all possible splittings of \(A T_g\), i.e., \(A T_g = [A_1^\top, A_2^\top]^\top\), with the associated measurements denoted by \(y_1\) and \(y_2\).
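As a concrete illustration, such a splitting can be sketched with a plain Bernoulli mask in torch (a minimal sketch, not the deepinv.physics.generator.BernoulliSplittingMaskGenerator implementation; the 0.9 keep ratio is illustrative):

```python
import torch

torch.manual_seed(0)

y = torch.randn(1, 1, 8, 8)  # full measurement
# Bernoulli mask: 1 -> kept for the input split A_1, 0 -> held out for A_2
mask = (torch.rand_like(y) < 0.9).float()

y1 = mask * y        # measurements seen by the reconstructor
y2 = (1 - mask) * y  # held-out measurements for the prediction term

# The two splits are disjoint and together recover y
assert torch.allclose(y1 + y2, y)
```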

The main idea behind equivariant splitting is that the more the reconstructor is equivariant to suitable transformations, the better the final performance will be. A general way to make a reconstructor \(\tilde{R}\) equivariant is to add a group averaging step in the reconstructor,

\[R(y, A) = \frac{1}{|\mathcal{G}|}\sum_{g\in \mathcal{G}} T_g \tilde{R}(y, A T_g)\]

which is generally estimated with a Monte Carlo approach at training time. For this reason, EquivariantSplittingLoss takes two different instances of deepinv.transform.Transform as input: transform, used during training, and eval_transform, used during evaluation.
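The group averaging above can be sketched as follows, assuming a hypothetical base reconstructor `tilde_R` and a group of 90-degree rotations acting via torch.rot90 (an illustration only, not the deepinv.models.EquivariantReconstructor implementation):

```python
import torch

def reynolds_average(tilde_R, y, forward):
    """Average tilde_R over the group of 90-degree rotations.

    `forward` applies the physics A; `A T_g` is emulated by rotating the
    image before applying A, and `T_g` by rotating the estimate back.
    """
    outputs = []
    for k in range(4):  # g ranges over rotations by 0, 90, 180, 270 degrees
        # tilde_R(y, A T_g): reconstruct in the transformed frame
        x_g = tilde_R(y, lambda x: forward(torch.rot90(x, k, dims=(-2, -1))))
        # T_g tilde_R(...): map the estimate back to the original frame
        outputs.append(torch.rot90(x_g, k, dims=(-2, -1)))
    return torch.stack(outputs).mean(dim=0)
```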

It is also possible to design an equivariant reconstructor without Reynolds averaging, using equivariant layers. In that case, Reynolds averaging can be disabled, avoiding its additional computational cost, by leaving both transform and eval_transform set to None.

The training loss consists of two terms: a consistency term, where the comparison is performed against \(A_1 x\), and a prediction term, where the comparison is performed against \(A_2 x\). Two parameters control how these terms are computed: consistency_loss and prediction_loss.

In the absence of noise, the equivariant splitting loss \(\mathcal{L}_{\mathrm{ES}}\) can be computed exactly without access to ground truth images, since in that case \(A_1 x = y_1\) and \(A_2 x = y_2\). Setting consistency_loss and prediction_loss to deepinv.loss.MCLoss(metric=deepinv.metric.MSE()) computes the loss this way.
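The noiseless identity can be checked directly with plain torch, taking both the inpainting physics and the splitting to be diagonal masks (a sketch with illustrative mask ratios, independent of the deepinv API):

```python
import torch

torch.manual_seed(0)
x = torch.rand(1, 1, 8, 8)               # ground truth image
A = (torch.rand_like(x) < 0.5).float()   # inpainting physics (diagonal mask)
y = A * x                                # noiseless measurement

split = (torch.rand_like(x) < 0.9).float()  # splitting mask
A1, A2 = split * A, (1 - split) * A
y1, y2 = split * y, (1 - split) * y

# Without noise, A_1 x = y_1 and A_2 x = y_2 hold exactly
assert torch.allclose(A1 * x, y1)
assert torch.allclose(A2 * x, y2)
```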

In the presence of noise, as long as the splitting scheme is chosen so that the resulting noise components are independent, the prediction term can be estimated without bias using deepinv.loss.MCLoss(metric=deepinv.metric.MSE()) for prediction_loss. This is notably the case for typical splitting schemes, e.g., deepinv.physics.generator.BernoulliSplittingMaskGenerator when the noise is pixel-wise independent, e.g., deepinv.physics.GaussianNoise.

The consistency term should be set to one of the self-supervised denoising losses listed in Self-Supervised Learning, e.g., deepinv.loss.R2RLoss or deepinv.loss.SureGaussianLoss if the noise distribution is known exactly. If the noise parameters are unknown, UNSURE can be used instead, i.e., deepinv.loss.SureGaussianLoss with the option unsure enabled, and if the noise distribution is unknown altogether, the consistency term can be estimated using the Noise2x family of losses.

At training time, a single splitting is performed for each sample in the batch. At evaluation time, however, the reconstructions are averaged over multiple splittings, as specified by eval_n_samples.
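The evaluation-time averaging can be sketched as follows, with a hypothetical `splitting_model` callable that takes the split measurements and the mask (the function name and parameters are illustrative, not the internal deepinv interface):

```python
import torch

def eval_reconstruct(splitting_model, y, eval_n_samples=5, split_ratio=0.9):
    """Average reconstructions over several random Bernoulli splittings."""
    estimates = []
    for _ in range(eval_n_samples):
        # draw a fresh splitting mask for each Monte Carlo sample
        mask = (torch.rand_like(y) < split_ratio).float()
        # reconstruct from the kept measurements y_1 = mask * y
        estimates.append(splitting_model(mask * y, mask))
    return torch.stack(estimates).mean(dim=0)
```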

Parameters:
  • mask_generator (PhysicsGenerator, None) – the generator specifying the distribution of splittings. Defaults to a deepinv.physics.generator.BernoulliSplittingMaskGenerator with image size specified by img_size.

  • consistency_loss (Loss, None) – the loss used to compute the consistency term. Defaults to a deepinv.loss.MCLoss.

  • prediction_loss (Loss, None) – the loss used to compute the prediction term. Defaults to a deepinv.loss.MCLoss.

  • eval_n_samples (int) – the number of splittings over which reconstructions are averaged at evaluation time. Defaults to 5.

  • transform (Transform, None) – transformations to be used in training mode for Reynolds averaging (optional).

  • eval_transform (Transform, None) – transformations to be used in evaluation mode for Reynolds averaging. It can be used to have true Reynolds averaging at evaluation time and efficient Monte Carlo estimation at training time. If left unspecified, the value of transform is used at evaluation time as well.

  • img_size (tuple[int, ...]) – the image size for the fallback splitting scheme (optional). It is only used if mask_generator is not specified.


Example:

>>> import torch
>>> import deepinv as dinv
>>> physics = dinv.physics.Inpainting(img_size=(1, 8, 8), mask=0.5)
>>> model = dinv.models.RAM(pretrained=True)
>>> mask_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator(
...     img_size=(1, 8, 8),
...     split_ratio=0.9,
...     pixelwise=True,
... )
>>> train_transform = dinv.transform.Rotate(
...     n_trans=1, multiples=90, positive=True
... ) * dinv.transform.Reflect(n_trans=1, dim=[-1])
>>> eval_transform = dinv.transform.Rotate(
...     n_trans=4, multiples=90, positive=True
... ) * dinv.transform.Reflect(n_trans=2, dim=[-1])
>>> loss = dinv.loss.EquivariantSplittingLoss(
...     mask_generator=mask_generator,
...     consistency_loss=dinv.loss.MCLoss(metric=dinv.metric.MSE()),
...     prediction_loss=dinv.loss.MCLoss(metric=dinv.metric.MSE()),
...     transform=train_transform,
...     eval_transform=eval_transform,
...     eval_n_samples=5,
... )
>>> eq_model = loss.adapt_model(model) # turn into equiv. reconstructor
>>> x = torch.ones((1, 1, 8, 8))
>>> y = physics(x)
>>> x_net = eq_model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, eq_model)
>>> print(l.item() > 0)
True


References:

adapt_model(model)[source]#

Adapt the reconstructor for equivariant splitting.

It wraps the input reconstructor in a splitting model and, if transforms are provided, in a deepinv.models.EquivariantReconstructor.

Parameters:

model (Reconstructor) – the reconstructor to adapt.

Returns:

the adapted reconstructor.

forward(x_net, y, physics, model, **kwargs)[source]#

Compute the equivariant splitting loss.

Parameters:

  • x_net (torch.Tensor) – the reconstructed image.

  • y (torch.Tensor) – the measurements.

  • physics (Physics) – the forward physics operator.

  • model (Reconstructor) – the adapted reconstructor.

Returns:

(torch.Tensor) the loss value.

property name: str#

The name of the loss function. This attribute is deprecated in favor of the class name and it will be removed in a future version.

Examples using EquivariantSplittingLoss:#

Self-supervised learning with Equivariant Splitting
