WeightedSplittingLoss

class deepinv.loss.mri.WeightedSplittingLoss(mask_generator, physics_generator, metric=None)

Bases: SplittingLoss

K-Weighted Splitting Loss

Implements the K-weighted Noisier2Noise-SSDU loss from Millard and Chiew [1]. The loss is designed for problems where measurements are observed as \(y_i=M_iAx\), where \(M_i\) is a random mask, such as in MRI, where \(A\) is the Fourier transform. The loss is defined as follows, using the notation of deepinv.loss.SplittingLoss:

\[\frac{m}{m_2}\| (1-\mathbf{K})^{-1/2} (y_2 - A_2 \inversef{y_1}{A_1})\|^2\]

where \(\mathbf{K}\) is derived from the probability density functions (pdfs) of the original acceleration mask and of the further splitting mask:

\[\mathbf{K}=(\mathbb{I}_n-\tilde{\mathbf{P}}\mathbf{P})^{-1}(\mathbb{I}_n-\mathbf{P})\]

with \(\mathbf{P}=\mathbb{E}[\mathbf{M}_i]\) and \(\tilde{\mathbf{P}}=\mathbb{E}[\mathbf{M}_1]\), i.e. the average imaging mask and the average splitting mask, respectively. At inference, the whole original measurement \(y\) is used as input.
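Because \(\mathbf{P}\), \(\tilde{\mathbf{P}}\) and \(\mathbf{K}\) are diagonal, the matrix inverse in the definition of \(\mathbf{K}\) reduces to an elementwise division, and the weight \((1-\mathbf{K})^{-1/2}\) can be computed directly from 1D pdfs. The following is a minimal sketch under our own assumptions, not deepinv's implementation: `k_weight` and `weighted_split_loss` are hypothetical names, the `eps` clamping is an assumed numerical safeguard, and \(m\), \(m_2\) are taken to be the number of sampled entries in \(y\) and \(y_2\).

```python
import torch

def k_weight(P: torch.Tensor, P_tilde: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Diagonal of (1 - K)^{-1/2}, with K = (1 - P_tilde * P)^{-1} (1 - P) elementwise."""
    P, P_tilde = P.double(), P_tilde.double()  # float64 so that 1 - eps != 1
    P_tilde = P_tilde.clamp(max=1 - eps)       # assumed safeguard: keep 1 - P_tilde * P > 0
    K = (1 - P) / (1 - P_tilde * P)
    return (1 - K).clamp(min=eps) ** -0.5      # 1 - K vanishes where P == 0

def weighted_split_loss(y2, y2_pred, mask, mask2, weight):
    """(m / m_2) * || weight * (y_2 - A_2 f(y_1, A_1)) ||^2 over the held-out entries.

    mask and mask2 are 1D width-wise masks (hypothetical inputs). Since they are
    repeated along the height, m / m_2 equals the ratio of their column counts.
    """
    residual = mask2 * (y2 - y2_pred)  # the 1D mask broadcasts over the height
    return (mask.sum() / mask2.sum()) * torch.sum(torch.abs(weight * residual) ** 2)

P = torch.tensor([1.0, 0.5, 0.25, 0.25])      # pdf of the acquisition mask
P_tilde = torch.tensor([1.0, 0.5, 0.5, 0.5])  # pdf of the splitting mask
w = k_weight(P, P_tilde)  # approximately [1, sqrt(3), sqrt(7), sqrt(7)]

mask = torch.ones(4)                          # all four columns acquired
mask2 = torch.tensor([0.0, 1.0, 0.0, 1.0])    # columns 1 and 3 held out
loss = weighted_split_loss(torch.ones(2, 4), torch.zeros(2, 4), mask, mask2, w)  # approximately 40.0
```

Fully sampled locations (\(\mathbf{P}=1\)) get weight 1, while rarely sampled locations are up-weighted, which is how the weighting compensates for the variable-density sampling.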

Note

To match the original paper, the loss should be used with the splitting mask generator deepinv.physics.generator.MultiplicativeSplittingMaskGenerator, where the additional subsampling mask passed to it is of the same type as the one used to generate the measurements.

Note that the method was originally proposed for accelerated MRI, where the measurements are generated via a mask generator.

Note also that all masks are assumed to be 1D masks along the image width dimension, repeated across all other dimensions.

If the input data varies in shape, the loss dynamically recalculates the weight; however, each recalculation slows down training.
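The multiplicative splitting of 1D masks described in these notes can be illustrated with plain PyTorch. This is a sketch under our own assumptions, not the implementation of MultiplicativeSplittingMaskGenerator: Bernoulli masks stand in for the deepinv mask generators, \(A\) is taken to be an uncentred orthonormal 2D FFT, and `multiplicative_split` and `masked_fft` are hypothetical names.

```python
import torch

def multiplicative_split(mask: torch.Tensor, split_mask: torch.Tensor):
    """Split a 1D acquisition mask into an input mask and a held-out mask."""
    mask1 = split_mask * mask  # keep only acquired columns that the split mask also keeps
    mask2 = mask - mask1       # acquired columns withheld from the network input
    return mask1, mask2

def masked_fft(x: torch.Tensor, mask_1d: torch.Tensor) -> torch.Tensor:
    """Simulate M A x, with A an uncentred orthonormal 2D FFT and M a width-wise 1D mask."""
    kspace = torch.fft.fft2(x, norm="ortho")
    # Repeat the 1D mask along the height dimension before applying it.
    return mask_1d.reshape(1, -1).expand_as(kspace) * kspace

torch.manual_seed(0)
H, W = 8, 16
x = torch.randn(H, W)
mask = (torch.rand(W) < 0.5).float()        # hypothetical acquisition mask
split_mask = (torch.rand(W) < 0.5).float()  # hypothetical further splitting mask
mask1, mask2 = multiplicative_split(mask, split_mask)
y = masked_fft(x, mask)                     # full measurement, used at inference
y1, y2 = masked_fft(x, mask1), masked_fft(x, mask2)  # network input and loss target
```

Because the split mask only ever multiplies the acquisition mask, the input mask is always a subset of the acquired columns, and the two parts together recover the full measurement.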

Parameters:


Example:

>>> import torch
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> from deepinv.loss.mri import WeightedSplittingLoss
>>> physics_generator = GaussianMaskGenerator((128, 128), acceleration=4)
>>> split_generator = GaussianMaskGenerator((128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> loss = WeightedSplittingLoss(mask_generator, physics_generator)


References:

class WeightedMetric(mask_generator, physics_generator, pixel_metric)

Bases: Module

Wraps a metric to apply the weight to its inputs.

Note that mask_generator and physics_generator are only used to regenerate the weight when y changes shape during training.
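A minimal sketch of what such a wrapper does, assuming a precomputed weight tensor that broadcasts over the inputs (for example of shape (1, W)) instead of the mask_generator and physics_generator arguments the real class takes. `WeightedMetricSketch` is a hypothetical name, and the inputs are assumed to be real-valued tensors.

```python
import torch
from torch import nn

class WeightedMetricSketch(nn.Module):
    """Hypothetical stand-in: scale both inputs by a fixed weight, then apply a metric."""

    def __init__(self, weight: torch.Tensor, pixel_metric: nn.Module):
        super().__init__()
        self.register_buffer("weight", weight)  # moves with the module across devices
        self.pixel_metric = pixel_metric

    def forward(self, y1: torch.Tensor, y2: torch.Tensor) -> torch.Tensor:
        # For a difference-based metric such as MSE, weighting both inputs is the
        # same as weighting their difference: w * y1 - w * y2 == w * (y1 - y2).
        return self.pixel_metric(self.weight * y1, self.weight * y2)

weight = torch.tensor([1.0, 2.0, 3.0])
metric = WeightedMetricSketch(weight, nn.MSELoss(reduction="sum"))
value = metric(torch.ones(2, 3), torch.zeros(2, 3))  # 2 * (1 + 4 + 9) = 28
```

Applying the weight to the inputs rather than inside the metric lets any existing pixel-wise metric be reused unchanged.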

Parameters:
forward(y1, y2)

Weighted metric forward pass.

static compute_weight(mask_generator, physics_generator, eps=1e-9, img_size=None)

Compute the weight for the K-weighted splitting loss, where K is a diagonal matrix whose diagonal has shape (H, W); the returned weight is a tensor of shape (1, W).

Estimates the 1D PDFs of the mask generators empirically.
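Empirical pdf estimation can be sketched as averaging many sampled 1D masks and plugging the averages into the formula for \(\mathbf{K}\). This is an illustration under stated assumptions, not deepinv's compute_weight: the samplers are hypothetical zero-argument callables returning a 1D mask (here Bernoulli draws with known probabilities, so the estimate can be checked against the closed form), and `estimate_pdf_1d`, `compute_weight_sketch` and `n_samples` are our own names.

```python
import torch

def estimate_pdf_1d(sample_mask, n_samples: int) -> torch.Tensor:
    """Empirical pdf E[M]: average n_samples 1D masks drawn from sample_mask()."""
    return torch.stack([sample_mask() for _ in range(n_samples)]).mean(dim=0)

def compute_weight_sketch(sample_mask, sample_split, n_samples=10000, eps=1e-9):
    """Return (1 - K)^{-1/2} with shape (1, W), estimated from sampled masks."""
    P = estimate_pdf_1d(sample_mask, n_samples).double()
    P_tilde = estimate_pdf_1d(sample_split, n_samples).double().clamp(max=1 - eps)
    K = (1 - P) / (1 - P_tilde * P)
    weight = (1 - K).clamp(min=eps) ** -0.5
    return weight.unsqueeze(0)  # leading singleton dim broadcasts over the height H

torch.manual_seed(0)
probs = torch.tensor([0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1.0, 1.0])  # hypothetical acquisition pdf
split_probs = torch.full((8,), 0.5)                            # hypothetical splitting pdf
weight = compute_weight_sketch(lambda: torch.bernoulli(probs),
                               lambda: torch.bernoulli(split_probs))
```

Since the masks are 1D and repeated along the height, only W probabilities need to be estimated, and the resulting (1, W) weight broadcasts over every row of the measurement.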

Parameters: