WeightedSplittingLoss
- class deepinv.loss.mri.WeightedSplittingLoss(mask_generator, physics_generator, metric=None)
Bases: SplittingLoss
K-Weighted Splitting Loss
Implements the K-weighted Noisier2Noise-SSDU loss from Millard and Chiew [1]. The loss is designed for problems where measurements are observed as \(y_i=M_iAx\), where \(M_i\) is a random mask, such as in MRI, where \(A\) is the Fourier transform. Using the notation from deepinv.loss.SplittingLoss, the loss is defined as
\[\frac{m}{m_2}\| (\mathbb{I}_n-\mathbf{K})^{-1/2} (y_2 - A_2 \inversef{y_1}{A_1})\|^2\]
where \(\mathbf{K}\) is derived from the probability density functions (pdfs) of the (original) acceleration mask and the (further) splitting mask:
\[\mathbf{K}=(\mathbb{I}_n-\tilde{\mathbf{P}}\mathbf{P})^{-1}(\mathbb{I}_n-\mathbf{P})\]
and \(\mathbf{P}=\mathbb{E}[\mathbf{M}_i]\), \(\tilde{\mathbf{P}}=\mathbb{E}[\mathbf{M}_1]\) are the average imaging mask and splitting mask, respectively. At inference, the original full measurement \(y\) is used as input.
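Because \(\mathbf{P}\) and \(\tilde{\mathbf{P}}\) are diagonal, \(\mathbf{K}\) and the weight \((\mathbb{I}_n-\mathbf{K})^{-1/2}\) reduce to elementwise operations on the 1D column pdfs. The NumPy sketch below evaluates the two formulas above under that assumption. The names `k_weight` and `weighted_split_loss` are hypothetical illustrations, not deepinv API, and `m`, `m2` are passed in explicitly as the number of measurements in \(y\) and \(y_2\).

```python
import numpy as np


def k_weight(p, p_tilde, eps=1e-9):
    """Diagonal of (I_n - K)^{-1/2}, with K = (I_n - P~ P)^{-1} (I_n - P).

    p:       1D pdf of the original acceleration mask, diag(P), one value per column.
    p_tilde: 1D pdf of the splitting mask, diag(P~), same length as p.
    All matrices are diagonal, so every operation below is elementwise.
    """
    p = np.asarray(p, dtype=float)
    p_tilde = np.asarray(p_tilde, dtype=float)
    # eps in the denominator avoids 0/0 on columns where P~ P = 1
    k = (1.0 - p) / (1.0 - p_tilde * p + eps)
    # clip guards against tiny negative values from rounding; eps keeps the power finite
    return (np.clip(1.0 - k, 0.0, None) + eps) ** -0.5


def weighted_split_loss(y2, y2_hat, weight, m, m2):
    """(m / m2) * ||(I_n - K)^{-1/2} (y2 - y2_hat)||^2 for 1D column weights.

    y2:     held-out measurements y_2, shape (..., H, W)
    y2_hat: A_2 applied to the reconstruction from y_1, same shape as y2
    weight: output of k_weight, shape (W,), broadcast over all other dimensions
    m, m2:  number of measurements in y and in y_2
    """
    residual = weight * (y2 - y2_hat)
    return m / m2 * np.sum(np.abs(residual) ** 2)


# Two columns with P = 0.5 and P = 1.0, splitting pdf P~ = 0.5 in both
w = k_weight([0.5, 1.0], [0.5, 0.5])  # approximately [1.732, 1.0]
```

Since the squared weight equals \((1-\tilde{p}p)/(p(1-\tilde{p}))\) per column, columns that are sampled less often in the original mask (smaller \(p\)) receive larger weights, while always-sampled columns get a weight close to 1.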
Note
To match the original paper, the loss should be used with the splitting mask generator deepinv.physics.generator.MultiplicativeSplittingMaskGenerator, where the additional subsampling mask passed as input should be of the same type as the one used to generate the measurements. Note that the method was originally proposed for accelerated MRI problems, where the measurements are generated via a mask generator.
Note also that all masks are assumed to be 1D masks along the image width dimension, repeated across all other dimensions.
If the input data varies in shape, the loss dynamically recalculates the weight. Each recalculation adds overhead, so training slows down whenever the shape changes.
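The mask assumptions in this note can be illustrated with a small NumPy sketch: each mask is a 1D Bernoulli draw over the W columns repeated along the H rows, and the further split is obtained by multiplying the acquisition mask with an independently drawn mask of the same kind. That multiplicative construction is an assumption suggested by the name MultiplicativeSplittingMaskGenerator, not a description of its code; `sample_column_mask` and `multiplicative_split` are hypothetical helpers.

```python
import numpy as np


def sample_column_mask(pdf, height, rng):
    """Draw a 1D Bernoulli mask over the W columns and repeat it along the H rows."""
    pdf = np.asarray(pdf, dtype=float)
    cols = rng.random(pdf.shape[0]) < pdf  # shape (W,), True = column sampled
    return np.broadcast_to(cols, (height, pdf.shape[0])).astype(float)  # shape (H, W)


def multiplicative_split(acq_mask, split_pdf, rng):
    """Split an acquisition mask M_i into M_1 (network input) and M_2 (target).

    Assumption: M_1 is an independently drawn mask multiplied by M_i, so it only
    keeps measured columns; M_2 keeps the measured columns that M_1 dropped.
    """
    split = sample_column_mask(split_pdf, acq_mask.shape[0], rng)
    m1 = acq_mask * split
    m2 = acq_mask * (1.0 - split)
    return m1, m2


rng = np.random.default_rng(0)
acq_pdf = [0.25, 0.25, 1.0, 1.0, 0.25, 0.25]  # fully sampled center, sparser edges
acq_mask = sample_column_mask(acq_pdf, height=4, rng=rng)
# the further split is drawn with the same kind of sampler as the acquisition mask
m1, m2 = multiplicative_split(acq_mask, [0.5] * 6, rng)
```

By construction \(M_1 + M_2 = M_i\), so every measured column goes either to the network input or to the held-out target, and each mask is constant along the rows.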
- Parameters:
mask_generator (deepinv.physics.generator.BernoulliSplittingMaskGenerator) – splitting mask generator for further subsampling.
physics_generator (deepinv.physics.generator.BaseMaskGenerator) – original mask generator used to generate the measurements.
metric (Metric, torch.nn.Module) – metric used for computing data consistency, which is set as the mean squared error by default.
- Example:
>>> import torch
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> from deepinv.loss.mri import WeightedSplittingLoss
>>> physics_generator = GaussianMaskGenerator((128, 128), acceleration=4)
>>> split_generator = GaussianMaskGenerator((128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> loss = WeightedSplittingLoss(mask_generator, physics_generator)
- References:
- class WeightedMetric(mask_generator, physics_generator, pixel_metric)
Bases: Module
Wraps a metric to apply the loss weight to its inputs.
Note
mask_generator and physics_generator are only used to regenerate the weight when y changes shape during training.
- Parameters:
torch.Tensor – loss weight.
mask_generator (deepinv.physics.generator.BernoulliSplittingMaskGenerator) – splitting mask generator for further subsampling.
physics_generator (deepinv.physics.generator.BaseMaskGenerator) – original mask generator used to generate the measurements.
pixel_metric (Metric, torch.nn.Module) – loss metric.
expand (bool) – whether to expand the weight to the input dimensions.
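The regeneration behaviour (reuse the weight while the input shape is unchanged, recompute it otherwise) can be sketched as a small wrapper. This is not deepinv's WeightedMetric: `WeightedMetricSketch`, `weight_fn` and `n_recomputes` are hypothetical, the metric is plain NumPy, and keying the cache on the last (width) dimension is an assumption based on the weight having shape (1, W). NumPy broadcasting plays a role similar to expand=True here.

```python
import numpy as np


class WeightedMetricSketch:
    """Hypothetical stand-in for WeightedMetric: applies a (1, W) weight to both inputs.

    weight_fn stands in for regenerating the weight from the mask generators;
    it maps an image width W to a weight array of shape (1, W).
    """

    def __init__(self, weight_fn, pixel_metric):
        self.weight_fn = weight_fn
        self.pixel_metric = pixel_metric
        self.weight = None
        self.n_recomputes = 0  # counts the slow path, for illustration

    def __call__(self, y_pred, y_true):
        width = y_true.shape[-1]
        if self.weight is None or self.weight.shape[-1] != width:
            # first call or new input width: regenerate the weight (slow path)
            self.weight = self.weight_fn(width)
            self.n_recomputes += 1
        # broadcasting expands (1, W) to (..., H, W)
        return self.pixel_metric(self.weight * y_pred, self.weight * y_true)


def mse(a, b):
    return float(np.mean((a - b) ** 2))


metric = WeightedMetricSketch(lambda width: np.full((1, width), 2.0), mse)
loss_a = metric(np.zeros((2, 1, 4, 8)), np.ones((2, 1, 4, 8)))  # computes the weight
loss_b = metric(np.zeros((2, 1, 4, 8)), np.ones((2, 1, 4, 8)))  # reuses it
loss_c = metric(np.zeros((2, 1, 4, 6)), np.ones((2, 1, 4, 6)))  # width changed: recompute
```

Only the first call and the call with a new width take the slow path, which is why training with fixed-size inputs avoids the recalculation cost.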
- static compute_weight(mask_generator, physics_generator, eps=1e-9, img_size=None)
Compute the weight for the K-weighted splitting loss, where K is a diagonal matrix of shape (H, W) and the returned weight is a tensor of shape (1, W). The 1D PDFs of the mask generators are estimated empirically.
- Parameters:
mask_generator (deepinv.physics.generator.BernoulliSplittingMaskGenerator) – splitting mask generator for further subsampling.
physics_generator (deepinv.physics.generator.BaseMaskGenerator) – original mask generator used to generate the measurements.
eps (float) – small value to avoid division by zero.
img_size (tuple) – desired mask shape (H, W). If None, use the default provided in physics_generator and mask_generator.
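compute_weight is documented as estimating the 1D PDFs empirically. One way to do that, sketched here under the 1D-mask assumption stated earlier, is to draw many masks, average them over samples and rows to estimate diag(P) and diag(P̃), and then evaluate K elementwise. `column_sampler`, `empirical_pdf_1d`, `empirical_weight` and `n_samples` are hypothetical; the number of samples and the eps handling used by deepinv are not assumed.

```python
import numpy as np


def column_sampler(pdf, height):
    """Hypothetical mask sampler: 1D Bernoulli columns repeated along H rows."""
    pdf = np.asarray(pdf, dtype=float)

    def sample(rng):
        cols = rng.random(pdf.shape[0]) < pdf
        return np.broadcast_to(cols, (height, pdf.shape[0])).astype(float)

    return sample


def empirical_pdf_1d(sample, n_samples, rng):
    """Average sampled (H, W) masks over samples and rows to estimate the 1D pdf."""
    masks = np.stack([sample(rng) for _ in range(n_samples)])  # (N, H, W)
    return masks.mean(axis=(0, 1))  # (W,)


def empirical_weight(sample_acq, sample_split, n_samples=2000, eps=1e-9, seed=0):
    """Monte Carlo estimate of the (1, W) weight (I_n - K)^{-1/2}."""
    rng = np.random.default_rng(seed)
    p = empirical_pdf_1d(sample_acq, n_samples, rng)  # estimate of diag(P)
    p_tilde = empirical_pdf_1d(sample_split, n_samples, rng)  # estimate of diag(P~)
    k = (1.0 - p) / (1.0 - p_tilde * p + eps)
    weight = (np.clip(1.0 - k, 0.0, None) + eps) ** -0.5
    return weight[None, :]  # shape (1, W)


weight = empirical_weight(
    column_sampler([0.5, 0.5, 1.0, 1.0, 0.5, 0.5], height=4),
    column_sampler([0.5] * 6, height=4),
)
```

With 2000 samples the outer columns (P = P̃ = 0.5) land close to the analytic value √3, and the always-sampled center columns get a weight of 1.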