MultiplicativeSplittingMaskGenerator
- class deepinv.physics.generator.MultiplicativeSplittingMaskGenerator(tensor_size, split_generator, *args, **kwargs)
Bases: BernoulliSplittingMaskGenerator
Multiplicative splitting mask generator.
Randomly generates binary masks using the given split_generator and multiplies them with the input_mask (i.e. the mask used to create the accelerated measurements). Given an acceleration mask \(M\) sampled from a known distribution, this generator provides masks \(M'=M_1 \circ M\), where \(\circ\) denotes the element-wise (Hadamard) product and \(M_1\) is sampled from split_generator, which typically samples from the same distribution as \(M\).
See also
deepinv.loss.mri.WeightedSplittingLoss
K-weighted splitting loss proposed in Millard and Chiew, where this splitting mask generator is used for self-supervised learning.
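To make the product \(M'=M_1 \circ M\) concrete, here is a minimal pure-PyTorch sketch. It does not call deepinv: the helper `multiplicative_split` and its `keep_prob` argument are hypothetical, and \(M_1\) is drawn i.i.d. Bernoulli purely for simplicity, whereas the class documented here draws \(M_1\) from `split_generator`. The point it illustrates holds for any binary \(M_1\): the split mask can only remove samples from \(M\), never add them.

```python
from typing import Optional

import torch


def multiplicative_split(
    input_mask: torch.Tensor,
    keep_prob: float,
    rng: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Return M' = M_1 * M (element-wise), with M_1 ~ i.i.d. Bernoulli(keep_prob).

    Hypothetical helper for illustration only: the documented class draws M_1
    from ``split_generator`` (e.g. a GaussianMaskGenerator), not from a plain
    Bernoulli distribution.
    """
    # Per-entry keep probability, same shape and device as the input mask.
    probs = torch.full_like(input_mask, keep_prob, dtype=torch.float32)
    # Binary splitting mask M_1 with entries in {0, 1}.
    m1 = torch.bernoulli(probs, generator=rng).to(input_mask.dtype)
    # Hadamard product: an entry is kept only if both M_1 and M keep it.
    return m1 * input_mask


# Usage: a random stand-in for the acceleration mask M, shape (B, C, H, W).
rng = torch.Generator().manual_seed(0)
M = torch.bernoulli(torch.full((2, 1, 8, 8), 0.25), generator=rng)
M_split = multiplicative_split(M, keep_prob=0.5, rng=rng)

# M' never keeps a sample that M had discarded.
print(bool(torch.all(M_split <= M)))  # True
```

Because both factors are binary, every entry of \(M'\) is at most the corresponding entry of \(M\), which is why the check prints True regardless of the seed.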
- Examples:
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> physics_generator = GaussianMaskGenerator((1, 128, 128), acceleration=4)
>>> orig_mask = physics_generator.step(batch_size=2)["mask"]
>>> split_generator = GaussianMaskGenerator((1, 128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> mask_generator.step(batch_size=2, input_mask=orig_mask)["mask"].shape
torch.Size([2, 1, 128, 128])
- Parameters:
tensor_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, e.g. of shape (C, H, W) or (C, T, H, W).
split_generator (deepinv.physics.generator.BaseMaskGenerator) – mask generator used for multiplicative splitting
device (str, torch.device) – device where the tensor is stored (default: ‘cpu’).
rng (torch.Generator) – torch random number generator.
dtype (torch.dtype) – the data type of the generated parameters.
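The following pure-PyTorch sketch shows what tensor_size, device, rng and dtype typically control when sampling a batch of masks. It is not the deepinv implementation: the helper `sample_split_masks` and its `keep_prob` argument are hypothetical, and \(M_1\) is again drawn i.i.d. Bernoulli for simplicity. It illustrates two conventions from the parameter list: tensor_size excludes the batch dimension, and passing identically seeded generators as rng yields identical masks.

```python
from typing import Optional, Tuple

import torch


def sample_split_masks(
    tensor_size: Tuple[int, ...],
    input_mask: torch.Tensor,
    batch_size: int,
    keep_prob: float = 0.5,
    device: str = "cpu",
    rng: Optional[torch.Generator] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Hypothetical helper mirroring the roles of tensor_size, device, rng and dtype."""
    # tensor_size excludes the batch dimension, so B is prepended:
    # (C, H, W) -> (B, C, H, W) or (C, T, H, W) -> (B, C, T, H, W).
    shape = (batch_size, *tensor_size)
    probs = torch.full(shape, keep_prob, device=device, dtype=dtype)
    # Drawing from an explicit generator makes the sample reproducible.
    m1 = torch.bernoulli(probs, generator=rng)
    return m1 * input_mask.to(device=device, dtype=dtype)


tensor_size = (2, 4, 16, 16)  # (C, T, H, W), e.g. dynamic MRI
M = torch.ones(3, *tensor_size)  # fully sampled stand-in for the acceleration mask
a = sample_split_masks(tensor_size, M, batch_size=3, rng=torch.Generator().manual_seed(42))
b = sample_split_masks(tensor_size, M, batch_size=3, rng=torch.Generator().manual_seed(42))
print(a.shape)  # torch.Size([3, 2, 4, 16, 16])
print(torch.equal(a, b))  # True
```

Seeding a dedicated torch.Generator, rather than the global RNG, keeps mask sampling reproducible without affecting other random draws in a training loop.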