EquivariantSplittingLoss
- class deepinv.loss.EquivariantSplittingLoss(*, mask_generator=None, consistency_loss=None, prediction_loss=None, eval_n_samples=5, transform=None, eval_transform=None, img_size=None)
Bases: Loss
Equivariant splitting loss.
Implements the measurement splitting loss proposed by Sechaud et al. [1]. It generalizes the regular
deepinv.loss.SplittingLoss by adding a measurement consistency term that supports noiseless losses such as deepinv.loss.MCLoss as well as noise-aware losses such as deepinv.loss.R2RLoss and deepinv.loss.SureGaussianLoss. Moreover, it automatically renders the base reconstructor equivariant using the Reynolds averaging implemented in deepinv.models.EquivariantReconstructor.
The training loss takes the general form:
\[\mathcal{L}_{\mathrm{ES}} (y, A, R) = \mathbb{E}_g \Big\{ \mathbb{E}_{y_1, A_1 \mid y, A T_g} \Big\{ \underbrace{\| A_1 R(y_1, A_1) - A_1 x \|^2}_{\text{Consistency term}} + \underbrace{\| A_2 R(y_1, A_1) - A_2 x \|^2}_{\text{Prediction term}} \Big\} \Big\}\]
where \(R\) denotes the reconstructor, \(A\) the physics operator, \(x\) the ground truth image, \(y\) the measurement, and \(T_g\) a group action (e.g., rotations).
The inner expectation is taken over the distribution, specified by mask_generator, of all possible splittings of \(A T_g\), i.e., \(A T_g = [A_1^\top, A_2^\top]^\top\), with the associated measurements denoted as \(y_1\) and \(y_2\).
The main idea behind equivariant splitting is that the more equivariant the reconstructor is to suitable transformations, the better the final performance will be. A general way to make a reconstructor \(\tilde{R}\) equivariant is to add a group averaging step in the reconstructor,
\[R(y, A) = \frac{1}{|\mathcal{G}|}\sum_{g\in \mathcal{G}} T_g \tilde{R}(y, A T_g)\]
which is generally estimated using a Monte Carlo approach at training time. For this reason, EquivariantSplittingLoss takes two different instances of deepinv.transform.Transform as input: one for training, transform, and one for evaluation, eval_transform.
It is also possible to design an equivariant reconstructor without Reynolds averaging, using equivariant layers. In that case, Reynolds averaging can be disabled to avoid its additional computational cost by leaving transform and eval_transform set to None.
The training loss consists of two terms: a consistency term, where the comparison is performed against \(A_1 x\), and a prediction term, where the comparison is performed against \(A_2 x\). Two parameters control the way these two terms are computed:
consistency_loss and prediction_loss.
In the absence of noise, the equivariant splitting loss \(\mathcal{L}_{\mathrm{ES}}\) can be computed exactly without having access to ground truth images. Indeed, in that case, \(A_1 x = y_1\) and \(A_2 x = y_2\). Setting consistency_loss and prediction_loss to deepinv.loss.MCLoss(metric=deepinv.metric.MSE()) computes the loss this way.
In the presence of noise, as long as the splitting scheme is chosen so that the resulting noise components are independent, the prediction term can be estimated without bias using deepinv.loss.MCLoss(metric=deepinv.metric.MSE()) for prediction_loss. This is notably the case for typical splitting schemes, e.g., deepinv.physics.generator.BernoulliSplittingMaskGenerator, when the noise is pixel-wise independent, e.g., deepinv.physics.GaussianNoise.
The consistency term should be set to one of the self-supervised denoising losses listed in Self-Supervised Learning, e.g., deepinv.loss.R2RLoss or deepinv.loss.SureGaussianLoss if the noise distribution is known exactly. If the noise parameters are unknown, UNSURE can be used instead, i.e., deepinv.loss.SureGaussianLoss with the option unsure enabled. If the noise distribution is unknown altogether, the consistency term can be estimated using the Noise2x family of losses.
At training time, a single splitting is performed for each sample in the batch. At evaluation time, however, the reconstructions are averaged over multiple splittings, as specified by eval_n_samples.
- Parameters:
mask_generator (PhysicsGenerator, None) – the generator specifying the distribution of splittings. Defaults to a deepinv.physics.generator.BernoulliSplittingMaskGenerator with image size specified by img_size.
consistency_loss (Loss, None) – the loss used to compute the consistency term. Defaults to a deepinv.loss.MCLoss.
prediction_loss (Loss, None) – the loss used to compute the prediction term. Defaults to a deepinv.loss.MCLoss.
transform (Transform, None) – transformations to be used in training mode for Reynolds averaging (optional).
eval_transform (Transform, None) – transformations to be used in evaluation mode for Reynolds averaging. It can be used to have true Reynolds averaging at evaluation time and efficient Monte Carlo estimation at training time. If left unspecified, the value of transform is used at evaluation time as well.
img_size (tuple[int, ...]) – the image size for the fallback splitting scheme (optional). It is only used if mask_generator is not specified.
- Example:
>>> import torch
>>> import deepinv as dinv
>>> physics = dinv.physics.Inpainting(img_size=(1, 8, 8), mask=0.5)
>>> model = dinv.models.RAM(pretrained=True)
>>> mask_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator(
...     img_size=(1, 8, 8),
...     split_ratio=0.9,
...     pixelwise=True,
... )
>>> train_transform = dinv.transform.Rotate(
...     n_trans=1, multiples=90, positive=True
... ) * dinv.transform.Reflect(n_trans=1, dim=[-1])
>>> eval_transform = dinv.transform.Rotate(
...     n_trans=4, multiples=90, positive=True
... ) * dinv.transform.Reflect(n_trans=2, dim=[-1])
>>> loss = dinv.loss.EquivariantSplittingLoss(
...     mask_generator=mask_generator,
...     consistency_loss=dinv.loss.MCLoss(metric=dinv.metric.MSE()),
...     prediction_loss=dinv.loss.MCLoss(metric=dinv.metric.MSE()),
...     transform=train_transform,
...     eval_transform=eval_transform,
...     eval_n_samples=5,
... )
>>> eq_model = loss.adapt_model(model)  # turn into equiv. reconstructor
>>> x = torch.ones((1, 1, 8, 8))
>>> y = physics(x)
>>> x_net = eq_model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, eq_model)
>>> print(l.item() > 0)
True
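The evaluation-time behaviour controlled by eval_n_samples, where reconstructions are averaged over several random splittings, can be illustrated with a minimal sketch in plain PyTorch. It is not the deepinv implementation: the helper average_over_splittings and the callable recon are hypothetical names, the physics is assumed to be a binary inpainting mask, and the splitting is assumed to be a pixel-wise Bernoulli draw.

```python
import torch


def average_over_splittings(recon, y, mask, n_samples=5, split_ratio=0.9, generator=None):
    """Average recon(y_1, A_1) over n_samples random splittings of an inpainting mask.

    `recon` is a hypothetical callable (y_1, mask_1) -> image, not a deepinv API.
    """
    total = torch.zeros_like(y)
    for _ in range(n_samples):
        # Bernoulli splitting: keep each observed pixel with probability split_ratio
        keep = (torch.rand(mask.shape, generator=generator) < split_ratio).to(y.dtype)
        mask1 = mask * keep  # A_1: the sub-sampled operator seen by the reconstructor
        total = total + recon(mask1 * y, mask1)
    return total / n_samples


# Toy check with a zero-filling "reconstructor" that simply returns its input
gen = torch.Generator().manual_seed(0)
x = torch.ones(1, 1, 8, 8)
mask = (torch.rand(x.shape, generator=gen) < 0.5).to(x.dtype)
y = mask * x
x_avg = average_over_splittings(lambda y1, m1: y1, y, mask, generator=gen)
```

With a real network in place of the zero-filling stand-in, averaging reduces the dependence of the output on any single random split, at the cost of n_samples forward passes.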
- References:
- adapt_model(model)
Adapt the reconstructor for equivariant splitting.
It wraps the input reconstructor in a splitting model and optionally in a deepinv.models.EquivariantReconstructor if requested.
- Parameters:
model (Reconstructor) – the reconstructor to adapt.
- Returns:
the adapted reconstructor.
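To illustrate what the Reynolds averaging step does, here is a minimal sketch in plain PyTorch. It is not the deepinv.models.EquivariantReconstructor implementation: it assumes the simplified case of an image-to-image map f with no physics operator, for which the averaging reduces to \(\frac{1}{|\mathcal{G}|}\sum_{g} T_g f(T_g^{-1} y)\), and it takes the group to be the four 90-degree rotations. The names reynolds_average, f and w are hypothetical.

```python
import torch


def reynolds_average(f, y):
    """Average f over the four 90-degree rotations of the last two dimensions."""
    out = torch.zeros_like(y)
    for k in range(4):
        # Apply the inverse rotation to the input, then rotate the output back
        y_g = torch.rot90(y, k=-k, dims=(-2, -1))
        out = out + torch.rot90(f(y_g), k=k, dims=(-2, -1))
    return out / 4


# Hypothetical non-equivariant map: the weights depend on the pixel position
torch.manual_seed(0)
w = torch.rand(1, 1, 8, 8)
f = lambda y: w * y


def rot(t):
    # One quarter turn, used as the test transformation
    return torch.rot90(t, k=1, dims=(-2, -1))


y = torch.rand(2, 1, 8, 8)
lhs = reynolds_average(f, rot(y))  # transform the input, then reconstruct
rhs = rot(reynolds_average(f, y))  # reconstruct, then transform the output
```

Here lhs and rhs agree up to floating-point error even though f alone does not commute with rotations. Summing over the full group, as in this sketch, gives exact equivariance; sampling a single group element per call gives the cheaper Monte Carlo estimate.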
- forward(x_net, y, physics, model, **kwargs)
Compute the equivariant splitting loss.
- Parameters:
x_net (torch.Tensor) – the reconstructed image.
y (torch.Tensor) – the measurement.
physics (Physics) – the physics operator.
model (Reconstructor) – the reconstruction function.
- Returns:
(torch.Tensor) the loss value.
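The two terms of the loss can be sketched in plain PyTorch for the noiseless case, where \(A_1 x = y_1\) and \(A_2 x = y_2\). This is an illustration under stated assumptions, not the deepinv implementation: the physics is assumed to be a binary inpainting mask, the splitting a pixel-wise Bernoulli draw, group averaging is omitted, and splitting_terms and recon are hypothetical names.

```python
import torch


def splitting_terms(recon, y, mask, split_ratio=0.9, generator=None):
    """Return (consistency, prediction) for one random splitting of an inpainting mask.

    `recon` is a hypothetical callable (y_1, mask_1) -> image, not a deepinv API.
    """
    # Split the observed pixels: A_1 keeps about split_ratio of them, A_2 the rest
    keep = (torch.rand(mask.shape, generator=generator) < split_ratio).to(y.dtype)
    mask1 = mask * keep
    mask2 = mask * (1 - keep)
    y1, y2 = mask1 * y, mask2 * y
    x_hat = recon(y1, mask1)  # the reconstructor only sees (y_1, A_1)
    consistency = ((mask1 * x_hat - y1) ** 2).mean()  # compares against y_1
    prediction = ((mask2 * x_hat - y2) ** 2).mean()  # compares against held-out y_2
    return consistency, prediction


gen = torch.Generator().manual_seed(0)
x = torch.rand(1, 1, 8, 8, generator=gen)
mask = (torch.rand(x.shape, generator=gen) < 0.5).to(x.dtype)
y = mask * x  # noiseless inpainting measurement

# Zero-filling stand-in: fits y_1 perfectly but cannot predict y_2
c_zero, p_zero = splitting_terms(lambda y1, m1: y1, y, mask, generator=gen)
# Oracle stand-in returning the ground truth: both terms vanish
c_oracle, p_oracle = splitting_terms(lambda y1, m1: x, y, mask, generator=gen)
```

The zero-filling stand-in drives the consistency term to zero while leaving the prediction term typically positive, which shows why the held-out measurements \(y_2\) are needed to rule out trivial solutions.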
Examples using EquivariantSplittingLoss:
Self-supervised learning with Equivariant Splitting