WeightedSplittingLoss
- class deepinv.loss.mri.WeightedSplittingLoss(mask_generator, physics_generator, eps=1e-9, metric=torch.nn.MSELoss())
Bases: SplittingLoss
K-Weighted Splitting Loss
Implements the K-weighted Noisier2Noise-SSDU loss from Millard and Chiew. The loss is designed for problems where measurements are observed as \(y_i=M_iAx\), where \(M_i\) is a random mask, such as in MRI, where \(A\) is the Fourier transform. The loss is defined as follows, using notation from deepinv.loss.SplittingLoss:
\[\frac{m}{m_2}\| (1-\mathbf{K})^{-1/2} (y_2 - A_2 \inversef{y_1}{A_1})\|^2\]
where \(\mathbf{K}\) is derived from the probability density functions (pdfs) of the (original) acceleration mask and the (further) splitting mask:
\[\mathbf{K}=(\mathbb{I}_n-\tilde{\mathbf{P}}\mathbf{P})^{-1}(\mathbb{I}_n-\mathbf{P})\]
and \(\mathbf{P}=\mathbb{E}[\mathbf{M}_i]\), \(\tilde{\mathbf{P}}=\mathbb{E}[\mathbf{M}_1]\), i.e. the average imaging mask and the average splitting mask, respectively. At inference, the original whole measurement \(y\) is used as input.
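When all masks are 1D along the image width and repeated in the other dimensions, \(\mathbf{P}\), \(\tilde{\mathbf{P}}\) and \(\mathbf{K}\) are diagonal, so the formulas above reduce to elementwise operations on per-column vectors. The NumPy sketch below evaluates them on a toy three-column problem. It is a minimal illustration of the equations, not deepinv's implementation: the helper names `k_diagonal` and `weighted_splitting_loss`, the toy pdfs, and the placement of `eps` are all hypothetical choices made for this example.

```python
import numpy as np


def k_diagonal(P, P_tilde, eps=1e-9):
    """Diagonal of K = (I - P_tilde P)^{-1} (I - P) for 1D column masks.

    P and P_tilde are diagonal, so the matrix formula is elementwise.
    eps (hypothetical placement) guards the denominator when P = P_tilde = 1.
    """
    return (1.0 - P) / (1.0 - P_tilde * P + eps)


def weighted_splitting_loss(y2, A2_xhat, P, P_tilde, m, m2, eps=1e-9):
    """(m / m2) * ||(1 - K)^{-1/2} (y2 - A2 xhat)||^2, computed columnwise."""
    K = k_diagonal(P, P_tilde, eps)
    weight = 1.0 / np.sqrt(1.0 - K + eps)
    residual = weight * (y2 - A2_xhat)
    return (m / m2) * np.sum(np.abs(residual) ** 2)


# Hypothetical 3-column example: the centre column is always acquired
# and never held out, so only columns 0 and 2 can appear in y_2.
P = np.array([0.25, 1.0, 0.25])        # pdf of the acquisition mask
P_tilde = np.array([0.5, 1.0, 0.5])    # pdf of the splitting mask
K = k_diagonal(P, P_tilde)             # approximately [6/7, 0, 6/7]

y2 = np.array([1.0, 0.0, 2.0])         # held-out measurements y_2
A2_xhat = np.array([0.0, 0.0, 1.0])    # A_2 applied to the network output
loss = weighted_splitting_loss(y2, A2_xhat, P, P_tilde, m=3, m2=2)
```

Here \(1-\mathbf{K}\approx 1/7\) in the held-out columns, so each unit residual is weighted by about 7 before the \(m/m_2 = 3/2\) rescaling, giving a loss of about 21.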
Note
To match the original paper, the loss should be used with the splitting mask generator deepinv.physics.generator.MultiplicativeSplittingMaskGenerator, where the additional subsampling mask passed as input should be of the same type as the one used to generate the measurements. Note that the method was originally proposed for accelerated MRI problems (where the measurements are generated via a mask generator).
Note also that all masks are assumed to be 1D masks along the image width dimension, repeated in all other dimensions.
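To make the 1D-mask assumption and the multiplicative split concrete, the NumPy simulation below draws column masks, repeats them along the height, and splits each acquisition mask with a second, independently drawn mask of the same family. It assumes, without confirming it from deepinv's source, that a multiplicative split keeps an acquired entry in \(y_1\) exactly when the second mask also keeps it, and sends the remaining acquired entries to \(y_2\). The image size, pdfs, and the helper `draw_1d_mask` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 8  # hypothetical image height and width

# Hypothetical per-column sampling probabilities along the width dimension.
P = np.full(W, 0.25)               # pdf of the acquisition mask, E[M_i]
P[W // 2 - 1 : W // 2 + 1] = 1.0   # two centre columns are always acquired
P_tilde = np.full(W, 0.5)          # pdf of the further splitting mask, E[M_1]


def draw_1d_mask(pdf, rng):
    """Draw one Bernoulli value per column and repeat it along the height."""
    cols = (rng.random(pdf.shape) < pdf).astype(float)
    return np.broadcast_to(cols, (H, pdf.size))


n_draws = 20_000
count_y1 = np.zeros((H, W))
for _ in range(n_draws):
    M_i = draw_1d_mask(P, rng)         # acquisition mask that produced y
    M_1 = draw_1d_mask(P_tilde, rng)   # independent mask of the same family
    keep_y1 = M_i * M_1                # entries of y passed to the network (y_1)
    keep_y2 = M_i * (1.0 - M_1)        # held-out entries used in the loss (y_2)
    assert np.array_equal(keep_y1 + keep_y2, M_i)  # the split partitions M_i
    count_y1 += keep_y1

freq_y1 = count_y1 / n_draws  # empirical probability that an entry lands in y_1
```

Every row of `freq_y1` is identical because each mask is a single row of column decisions broadcast over the height, and under independence the per-column rate approaches \(\tilde{\mathbf{P}}\mathbf{P}\), the product that appears in the definition of \(\mathbf{K}\).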
- Parameters:
mask_generator (deepinv.physics.generator.BernoulliSplittingMaskGenerator) – splitting mask generator for further subsampling.
physics_generator (deepinv.physics.generator.BaseMaskGenerator) – original mask generator used to generate the measurements.
eps (float) – small value to avoid division by zero.
metric (Metric, torch.nn.Module) – metric used to compute data consistency; defaults to the mean squared error.
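Why `eps` is needed can be seen from the elementwise identity \(1-\mathbf{K} = \mathbf{P}(1-\tilde{\mathbf{P}})/(1-\tilde{\mathbf{P}}\mathbf{P})\): the weight \((1-\mathbf{K})^{-1/2}\) is unbounded in any column with \(\mathbf{P}=0\). The NumPy sketch below, with hypothetical pdfs and residuals, shows that an unguarded computation turns such a column into `nan` even though its residual is zero, while a small additive guard keeps the loss finite. Where exactly deepinv adds `eps` is an assumption here, not something this page states.

```python
import numpy as np

# Hypothetical column pdfs: column 0 is never acquired (P = 0).
P = np.array([0.0, 0.5, 1.0])
P_tilde = np.array([0.5, 0.5, 0.5])
K = (1.0 - P) / (1.0 - P_tilde * P)     # [1.0, 2/3, 0.0], so 1 - K = 0 in column 0
residual = np.array([0.0, 0.3, -0.1])  # y_2 - A_2 xhat is zero where nothing is measured

# Without a guard, column 0 evaluates 0 / sqrt(0) = nan and poisons the sum.
with np.errstate(divide="ignore", invalid="ignore"):
    unguarded = np.sum((residual / np.sqrt(1.0 - K)) ** 2)

# A small additive eps (mirroring the eps parameter) keeps the loss finite.
eps = 1e-9
guarded = np.sum((residual / np.sqrt(1.0 - K + eps)) ** 2)
```

With the guard, column 0 contributes 0 and the remaining columns contribute about 0.27 and 0.01, so `guarded` is about 0.28.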
- Example:
>>> import torch
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> from deepinv.loss.mri import WeightedSplittingLoss
>>> physics_generator = GaussianMaskGenerator((128, 128), acceleration=4)
>>> split_generator = GaussianMaskGenerator((128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> loss = WeightedSplittingLoss(mask_generator, physics_generator)