WeightedSplittingLoss

class deepinv.loss.mri.WeightedSplittingLoss(mask_generator, physics_generator, eps=1e-9, metric=torch.nn.MSELoss())

Bases: SplittingLoss

K-Weighted Splitting Loss

Implements the K-weighted Noisier2Noise-SSDU loss from Millard and Chiew. The loss is designed for problems where measurements are observed as \(y_i=M_iAx\), where \(M_i\) is a random mask, such as in MRI, where \(A\) is the Fourier transform. The loss is defined as follows, using notation from deepinv.loss.SplittingLoss:

\[\frac{m}{m_2}\| (1-\mathbf{K})^{-1/2} (y_2 - A_2 \inversef{y_1}{A_1})\|^2\]

where \(\mathbf{K}\) is derived from the probability density function (pdf) of the (original) acceleration mask and (further) splitting mask:

\[\mathbf{K}=(\mathbb{I}_n-\tilde{\mathbf{P}}\mathbf{P})^{-1}(\mathbb{I}_n-\mathbf{P})\]

and \(\mathbf{P}=\mathbb{E}[\mathbf{M}_i],\tilde{\mathbf{P}}=\mathbb{E}[\mathbf{M}_1]\), i.e. the average imaging mask and the average splitting mask, respectively. At inference, the original whole measurement \(y\) is used as input.
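Because \(\mathbf{K}\) is diagonal, the matrix inverse in its definition reduces to an elementwise division. The following is a minimal sketch of that arithmetic in PyTorch, not the library's implementation: the probability values are made up for illustration, and the way `eps` and the clamp are applied here is an assumption.

```python
import torch

# Hypothetical 1D sampling probabilities along the width dimension (W = 4).
# P:       probability that the acquisition mask samples each column.
# P_tilde: probability that the splitting mask keeps each column.
P = torch.tensor([1.0, 0.5, 0.25, 0.25])
P_tilde = torch.tensor([0.9, 0.8, 0.5, 0.5])

eps = 1e-9  # assumed guard against division by zero

# K = (I - P_tilde P)^{-1} (I - P), computed on the diagonal entries.
K = (1 - P) / (1 - P_tilde * P + eps)

# Diagonal of (1 - K)^{-1/2}, the weight applied to the residual in the loss.
weight = (1 - K).clamp_min(eps).rsqrt()
```

In this sketch, columns that are always acquired (\(\mathbf{P}=1\)) get \(\mathbf{K}=0\) and therefore a weight of 1. Elementwise, \(1-\mathbf{K}\) simplifies to \(\mathbf{P}(1-\tilde{\mathbf{P}})/(1-\tilde{\mathbf{P}}\mathbf{P})\), so columns that are rarely acquired receive a larger weight.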

Note

To match the original paper, the loss should be used with the splitting mask generator deepinv.physics.generator.MultiplicativeSplittingMaskGenerator, where the additional subsampling mask generator passed to it should be of the same type as the one used to generate the measurements.

Note the method was originally proposed for accelerated MRI problems (where the measurements are generated via a mask generator).

Note also that we assume that all masks are 1D masks along the image width dimension, repeated along all other dimensions.

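Parameters:
  • mask_generator – splitting mask generator used to further subsample the measurements, e.g. deepinv.physics.generator.MultiplicativeSplittingMaskGenerator.

  • physics_generator – original mask generator used to generate the measurements.

  • eps (float) – small value to avoid division by zero.

  • metric (Metric, torch.nn.Module) – metric applied to the weighted residual; defaults to torch.nn.MSELoss().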

Example:

>>> import torch
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> from deepinv.loss.mri import WeightedSplittingLoss
>>> physics_generator = GaussianMaskGenerator((128, 128), acceleration=4)
>>> split_generator = GaussianMaskGenerator((128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> loss = WeightedSplittingLoss(mask_generator, physics_generator)

class WeightedMetric(weight, metric, expand=True)

Bases: Module

Wraps a metric so that a weight is applied to its inputs.

Parameters:
  • weight (torch.Tensor) – loss weight.

  • metric (Metric, torch.nn.Module) – loss metric.

  • expand (bool) – whether to expand the weight to the input dimensions.

forward(y1, y2)

Weighted metric forward pass.
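To illustrate what wrapping a metric with a weight can look like, here is a minimal sketch in PyTorch. The class `WeightedMetricSketch` is a hypothetical stand-in written for this example, not the library's code; in particular, the interpretation of `expand` (adding leading singleton dimensions so that an (H, W) weight broadcasts against (B, C, H, W) inputs) is an assumption.

```python
import torch


class WeightedMetricSketch(torch.nn.Module):
    """Illustrative stand-in: applies a weight to both inputs of a metric."""

    def __init__(self, weight, metric, expand=True):
        super().__init__()
        # Register as a buffer so the weight follows the module across devices.
        self.register_buffer("weight", weight)
        self.metric = metric
        self.expand = expand

    def forward(self, y1, y2):
        w = self.weight
        if self.expand:
            # Assumption: prepend singleton dims so w broadcasts against y1.
            w = w.reshape((1,) * (y1.dim() - w.dim()) + tuple(w.shape))
        return self.metric(w * y1, w * y2)


weight = torch.full((4, 4), 2.0)
metric = WeightedMetricSketch(weight, torch.nn.MSELoss())
y1 = torch.ones(1, 2, 4, 4)
y2 = torch.zeros(1, 2, 4, 4)
loss_value = metric(y1, y2)  # MSE between 2 * y1 and 2 * y2
```

With a squared-error metric, weighting both inputs by \(w\) is equivalent to weighting the residual \(y_1 - y_2\) by \(w\), which matches the form of the loss given at the top of this page.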

compute_k(eps=1e-9)

Compute K for the K-weighted splitting loss, where the diagonal matrix K is represented by a tensor of shape (H, W).

Estimates the 1D PDFs of the mask generators empirically.

Parameters:

eps (float) – small value to avoid division by zero.
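An empirical estimate of a 1D mask PDF can be obtained by averaging many sampled masks. The sketch below shows the idea with plain PyTorch and hypothetical Bernoulli masks; the sizes, the density `true_pdf`, and the sampling procedure are assumptions made for illustration and are not taken from the library.

```python
import torch

torch.manual_seed(0)

H, W, n_samples = 8, 16, 2000

# Hypothetical ground-truth 1D sampling density over the width dimension.
true_pdf = torch.linspace(0.2, 0.9, W)

# Draw 1D Bernoulli masks and repeat each one along the height dimension,
# mirroring the assumption that masks are 1D in width and constant elsewhere.
masks_1d = (torch.rand(n_samples, W) < true_pdf).float()
masks = masks_1d[:, None, :].expand(n_samples, H, W)

# Empirical PDF: average over the sampled masks.
pdf_2d = masks.mean(dim=0)  # shape (H, W), identical rows
pdf_1d = pdf_2d[0]          # shape (W,)
```

Applying the same averaging to the acquisition masks and to the splitting masks would give estimates of \(\mathbf{P}\) and \(\tilde{\mathbf{P}}\), from which K follows elementwise.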