SplittingLoss

class deepinv.loss.SplittingLoss(metric: Metric | Module = MSELoss(), split_ratio: float = 0.9, mask_generator: PhysicsGenerator | None = None, eval_n_samples=5, eval_split_input=True, eval_split_output=False, pixelwise=True)[source]

Bases: Loss

Measurement splitting loss.

Implements the measurement splitting loss from Yaman et al. (SSDU) for MRI, Hendriksen et al. (Noise2Inverse) for CT, and Acar et al. for dynamic MRI. See also deepinv.loss.Artifact2ArtifactLoss and deepinv.loss.Phase2PhaseLoss for related losses.

Splits the measurement and forward operator \(\forw{}\) (of size \(m\)) into two smaller pairs \((y_1,A_1)\) (of size \(m_1\)) and \((y_2,A_2)\) (of size \(m_2\)), to compute the self-supervised loss:

\[\frac{m}{m_2}\| y_2 - A_2 \inversef{y_1}{A_1}\|^2\]

where \(R\) is the trainable network, \(A_1 = M_1 \forw{}, A_2 = M_2 \forw{}\), and \(M_i\) are randomly generated masks (i.e. diagonal matrices) such that \(M_1+M_2=\mathbb{I}_m\).
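
To make the formula concrete, here is a minimal PyTorch sketch of the splitting and the resulting loss, with the forward operator taken as the identity and the reconstruction network replaced by a placeholder (an illustration only, not the library implementation):

>>> import torch
>>> torch.manual_seed(0)  # doctest: +SKIP
>>> m = 64                                  # total measurement size
>>> y = torch.randn(m)                      # toy measurement, with A taken as the identity
>>> M1 = (torch.rand(m) < 0.9).float()      # keeps roughly split_ratio = 0.9 of the entries
>>> M2 = 1.0 - M1                           # complementary mask, M_1 + M_2 = I
>>> y1, y2 = M1 * y, M2 * y                 # split measurement pair
>>> x_hat = y1                              # placeholder reconstruction R(y_1, A_1)
>>> m2 = M2.sum()
>>> loss = (m / m2) * ((y2 - M2 * x_hat) ** 2).sum()   # (m / m_2) * || y_2 - M_2 x_hat ||^2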

See Self-supervised learning with measurement splitting for a usage example.

Note

If the forward operator has its own subsampling mask \(M_{\forw{}}\), e.g. deepinv.physics.Inpainting or deepinv.physics.MRI, the splitting masks will be subsets of the physics’ mask such that \(M_1+M_2=M_{\forw{}}\).
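
A hypothetical tensor-level sketch of this constraint (not the library's mask generator), where M_A plays the role of the physics' own mask:

>>> import torch
>>> M_A = (torch.rand(64) < 0.5).float()    # the physics' own subsampling mask
>>> bern = (torch.rand(64) < 0.9).float()   # Bernoulli draw with split_ratio = 0.9
>>> M1 = M_A * bern                         # splitting mask restricted to the physics support
>>> M2 = M_A - M1                           # complement within the support
>>> bool(torch.equal(M1 + M2, M_A))         # M_1 + M_2 = M_A rather than the identity
True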

This loss was used in SSDU for MRI in Yaman et al., "Self-supervised learning of physics-guided reconstruction neural networks without fully sampled reference data".

By default, the error is computed using the MSE metric; however, any appropriate metric can be used.
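
For example, since metric accepts any torch.nn.Module criterion, an L1 metric can be swapped in (a short usage sketch):

>>> import torch
>>> import deepinv as dinv
>>> # Replace the default MSE data-consistency metric by an L1 criterion
>>> loss = dinv.loss.SplittingLoss(metric=torch.nn.L1Loss(), split_ratio=0.6)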

Warning

The model should be adapted before training using the method adapt_model() to include the splitting mechanism at the input.

Note

To obtain the best test performance, the output of the trained model should be averaged at test time over multiple realizations of the splitting, i.e. \(\hat{x} = \frac{1}{N}\sum_{i=1}^N \inversef{y_1^{(i)}}{A_1^{(i)}}\). To disable this averaging, set eval_n_samples=1.
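
A sketch of this test-time averaging, reusing the Inpainting setup from the example below:

>>> import torch
>>> import deepinv as dinv
>>> physics = dinv.physics.Inpainting(tensor_size=(1, 8, 8), mask=0.5)
>>> loss = dinv.loss.SplittingLoss(split_ratio=0.9, eval_n_samples=5)
>>> model = loss.adapt_model(dinv.models.MedianFilter())
>>> y = physics(torch.ones((1, 1, 8, 8)))
>>> model.eval()              # doctest: +SKIP
>>> x_hat = model(y, physics) # averages over eval_n_samples splitting realizations
>>> # Constructing the loss with eval_n_samples=1 disables this averaging.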

Note

To disable measurement splitting (and use the full input) at evaluation time, set eval_split_input=False. This is done in SSDU.

Parameters:
  • metric (Metric, torch.nn.Module) – metric used for computing data consistency, which is set as the mean squared error by default.

  • split_ratio (float) – splitting ratio, should be between 0 and 1. The size of \(y_1\) increases with the splitting ratio. Ignored if mask_generator is passed.

  • mask_generator (deepinv.physics.generator.PhysicsGenerator, None) – function to generate the splitting mask. If None, the deepinv.physics.generator.BernoulliSplittingMaskGenerator is used, with the parameters split_ratio and pixelwise (see the sketch after this parameter list).

  • eval_n_samples (int) – Number of samples used for averaging at evaluation time. Must be greater than 0.

  • eval_split_input (bool) – if True, perform input measurement splitting during evaluation. If False, use the full measurement at evaluation time (no Monte Carlo samples are taken and eval_split_output has no effect).

  • eval_split_output (bool) – at evaluation time, pass the output through the output mask too, i.e. \((\sum_{j=1}^N M_2^{(j)})^{-1} \sum_{i=1}^N M_2^{(i)} \inversef{y_1^{(i)}}{A_1^{(i)}}\). Only valid when \(y\) is in the same domain (and of the same dimension) as \(x\). Although better results may be observed on small datasets, more samples must be used for larger images. Defaults to False.

  • pixelwise (bool) – if True, create pixelwise splitting masks, i.e. zero all channels simultaneously. Ignored if mask_generator is passed.

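As referenced in the mask_generator entry above, a sketch of passing a custom splitting mask generator (the tensor_size argument is an assumption of this sketch about the generator's constructor; split_ratio and pixelwise are the parameters documented above):

>>> import deepinv as dinv
>>> from deepinv.physics.generator import BernoulliSplittingMaskGenerator
>>> gen = BernoulliSplittingMaskGenerator(tensor_size=(1, 8, 8), split_ratio=0.6, pixelwise=True)
>>> # With a mask_generator supplied, SplittingLoss ignores its own split_ratio/pixelwise
>>> loss = dinv.loss.SplittingLoss(mask_generator=gen, eval_n_samples=2)
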

Example:

>>> import torch
>>> import deepinv as dinv
>>> physics = dinv.physics.Inpainting(tensor_size=(1, 8, 8), mask=0.5)
>>> model = dinv.models.MedianFilter()
>>> loss = dinv.loss.SplittingLoss(split_ratio=0.9, eval_n_samples=2)
>>> model = loss.adapt_model(model) # important step!
>>> x = torch.ones((1, 1, 8, 8))
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save random mask in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
adapt_model(model: Module, eval_n_samples=None) → SplittingModel[source]

Apply random splitting to input.

This method modifies a reconstruction model \(R\) to include the splitting mechanism at the input:

\[\hat{R}(y, A) = \frac{1}{N}\sum_{i=1}^N \inversef{y_1^{(i)}}{A_1^{(i)}}\]

where \(N\geq 1\) is the number of Monte Carlo samples, and \(y_1^{(i)}\) and \(A_1^{(i)}\) are obtained by randomly splitting the measurements \(y\) and operator \(A\). During training (i.e. when model.train()), we use only one sample, i.e. \(N=1\) for computational efficiency, whereas at test time, we use multiple samples for better performance. For other parameters that control how splitting is applied, see the class parameters.

Parameters:
  • model (torch.nn.Module) – Reconstruction model.

  • eval_n_samples (int) – deprecated. Pass eval_n_samples at class initialisation instead.

Returns:

(torch.nn.Module) Model modified for evaluation.

forward(x_net, y, physics, model, **kwargs)[source]

Computes the measurement splitting loss.

Parameters:
  • x_net (torch.Tensor) – Output of the reconstruction model, i.e. x_net = model(y, physics).

  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Forward operator.

  • model (torch.nn.Module) – Reconstruction model adapted with adapt_model().

Returns:

(torch.Tensor) loss.
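
A minimal training-step sketch built around forward(): TinyNet below is a hypothetical stand-in for a trainable reconstruction network taking (y, physics), and the optimizer settings are illustrative only:

>>> import torch
>>> import deepinv as dinv
>>> class TinyNet(torch.nn.Module):
...     """Hypothetical toy reconstruction network with the (y, physics) interface."""
...     def __init__(self):
...         super().__init__()
...         self.conv = torch.nn.Conv2d(1, 1, 3, padding=1)
...     def forward(self, y, physics, **kwargs):
...         return self.conv(physics.A_adjoint(y))
>>> physics = dinv.physics.Inpainting(tensor_size=(1, 8, 8), mask=0.5)
>>> loss = dinv.loss.SplittingLoss(split_ratio=0.9)
>>> model = loss.adapt_model(TinyNet())
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> y = physics(torch.ones((1, 1, 8, 8)))
>>> model.train()                                         # doctest: +SKIP
>>> x_net = model(y, physics, update_parameters=True)     # store the random split for the loss
>>> l = loss(x_net, y, physics, model)
>>> optimizer.zero_grad(); l.backward(); optimizer.step() # doctest: +SKIP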

static split(mask: Tensor, y: Tensor, physics: Physics | None = None)[source]

Perform splitting given the splitting mask.

Parameters:
  • mask (torch.Tensor) – Splitting mask.

  • y (torch.Tensor) – Measurement to split.

  • physics (deepinv.physics.Physics, None) – Optional forward operator to split along with the measurement.
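
Conceptually, the splitting amounts to applying the mask elementwise to the measurement (and, when a physics is passed, to its operator); a tensor-only sketch that does not rely on split's exact return values:

>>> import torch
>>> mask = (torch.rand(1, 1, 8, 8) < 0.6).float()   # splitting mask M_1
>>> y = torch.randn(1, 1, 8, 8)                     # measurement to split
>>> y1 = mask * y                                   # kept measurements y_1 = M_1 y
>>> y2 = (1 - mask) * y                             # held-out measurements y_2 = M_2 y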

Examples using SplittingLoss:

Self-supervised learning with measurement splitting

Self-supervised MRI reconstruction with Artifact2Artifact