MultiplicativeSplittingMaskGenerator

class deepinv.physics.generator.MultiplicativeSplittingMaskGenerator(tensor_size, split_generator, *args, **kwargs)

Bases: BernoulliSplittingMaskGenerator

Multiplicative splitting mask generator.

Randomly generates binary masks using the given split_generator and multiplies them element-wise with the input_mask (i.e. the mask used to create the accelerated measurements).

Given an acceleration mask \(M\) sampled from a known distribution, this generator provides masks \(M'=M_1 \circ M\) with \(M_1\) sampled from split_generator, which is typically the same distribution as \(M\).
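The relation \(M'=M_1 \circ M\) can be illustrated with a minimal, library-free sketch. It uses plain Bernoulli masks as stand-ins for both \(M\) and \(M_1\), and the helpers bernoulli_mask and multiply_masks are hypothetical names introduced for illustration only; they are not part of deepinv and do not reproduce how the class samples its masks.

```python
import random


def bernoulli_mask(height, width, keep_prob, rng):
    """Sample a binary mask whose entries are 1 with probability keep_prob."""
    return [
        [1 if rng.random() < keep_prob else 0 for _ in range(width)]
        for _ in range(height)
    ]


def multiply_masks(a, b):
    """Element-wise (Hadamard) product of two equally sized binary masks."""
    return [
        [x * y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(a, b)
    ]


rng = random.Random(0)  # fixed seed so the sketch is reproducible

M = bernoulli_mask(8, 8, 0.5, rng)   # stand-in for the acceleration mask M
M1 = bernoulli_mask(8, 8, 0.5, rng)  # stand-in for a draw from split_generator
M_prime = multiply_masks(M1, M)      # M' = M_1 ∘ M

# Every location kept by M' must also be kept by M.
is_subset = all(
    mp <= m
    for row_p, row in zip(M_prime, M)
    for mp, m in zip(row_p, row)
)
```

Because the masks are binary, the product can only remove locations from \(M\), never add them, so \(M'\) always selects a subset of the measurements that \(M\) acquired.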

See also

deepinv.loss.mri.WeightedSplittingLoss

K-weighted splitting loss proposed by Millard and Chiew, which uses this splitting mask generator for self-supervised learning.


Examples:
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> physics_generator = GaussianMaskGenerator((1, 128, 128), acceleration=4)
>>> orig_mask = physics_generator.step(batch_size=2)["mask"]
>>> split_generator = GaussianMaskGenerator((1, 128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> mask_generator.step(batch_size=2, input_mask=orig_mask)["mask"].shape
torch.Size([2, 1, 128, 128])
Parameters: