MultiplicativeSplittingMaskGenerator

class deepinv.physics.generator.MultiplicativeSplittingMaskGenerator(img_size, split_generator, *args, **kwargs)

Bases: BernoulliSplittingMaskGenerator

Multiplicative splitting mask generator.

Randomly generates binary masks using the given split_generator and multiplies them with the input_mask (i.e. the mask used to create the accelerated measurements).

Given an acceleration mask \(M\) sampled from a known distribution, this generator provides masks \(M'=M_1 \circ M\), where \(\circ\) denotes elementwise multiplication and \(M_1\) is sampled from split_generator, typically from the same distribution as \(M\).
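To make the composition concrete, the sketch below builds both masks with plain PyTorch Bernoulli sampling rather than the Gaussian generators used in the Examples section; the helper `bernoulli_mask` and the sampling probabilities are illustrative assumptions, not part of deepinv. Because \(M'\) is an elementwise product, every sample kept in \(M'\) was already acquired in \(M\), and with independent masks the kept fraction is roughly the product of the two sampling rates.

```python
import torch

def bernoulli_mask(shape, p, generator):
    """Hypothetical helper: i.i.d. binary mask where each entry is 1 with probability p."""
    return torch.bernoulli(torch.full(shape, p), generator=generator)

g = torch.Generator().manual_seed(0)
shape = (2, 1, 128, 128)  # (B, C, H, W)

M = bernoulli_mask(shape, 0.25, g)   # acceleration mask, keeps ~25% of samples (~4x)
M1 = bernoulli_mask(shape, 0.5, g)   # split mask, keeps ~50% of samples
M_split = M1 * M                     # M' = M_1 ∘ M (elementwise product)

# M' never selects a sample that M did not acquire
print(bool(torch.all(M_split <= M)))
# kept fraction should be close to 0.25 * 0.5 = 0.125
print(M_split.mean().item())
```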

See also

deepinv.loss.mri.WeightedSplittingLoss

K-weighted splitting loss proposed by Millard and Chiew [1], which uses this splitting mask generator for self-supervised learning.


Examples:
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> physics_generator = GaussianMaskGenerator((1, 128, 128), acceleration=4)
>>> orig_mask = physics_generator.step(batch_size=2)["mask"]
>>> split_generator = GaussianMaskGenerator((1, 128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> mask_generator.step(batch_size=2, input_mask=orig_mask)["mask"].shape
torch.Size([2, 1, 128, 128])

Parameters:
  • img_size (tuple[int]) – size of the tensor to be masked, excluding the batch dimension, e.g. of shape (C, H, W) or (C, T, H, W). Note that this can be overridden on the fly by passing the img_size or input_mask argument to step.

  • split_generator (deepinv.physics.generator.BaseMaskGenerator) – mask generator used for multiplicative splitting.

  • device (str, torch.device) – device where the tensor is stored (default: ‘cpu’).

  • rng (torch.Generator) – torch random number generator.

  • dtype (torch.dtype) – data type of the generated parameters.


References:

batch_step(input_mask=None, img_size=None)

Create a single splitting mask, i.e. one element of the batch.

Parameters:
  • input_mask (torch.Tensor, None) – optional mask to be split. If None, all pixels are considered. If not None, only pixels where mask==1 are considered. Batch dimension should not be included in shape.

  • img_size (tuple) – if not None, generate masks of this 2D image shape, overriding the img_size attribute; must be of the form (H, W).

Returns:

mask without batch dimension, with shape given by img_size, input_mask, or the class attribute img_size.

Return type:

dict
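The per-sample behaviour described above can be sketched as follows. This is not deepinv's implementation: `split_sampler` is a hypothetical stand-in for split_generator, the function names are illustrative, and the `"mask"` key simply mirrors the one read from step in the Examples section. The sketch draws a fresh \(M_1\) with the same shape as a single input mask (no batch dimension), multiplies the two, and then stacks the per-sample results to form a batch.

```python
import torch

def batch_step_sketch(input_mask, split_sampler):
    """Hypothetical sketch: split one mask (no batch dimension) multiplicatively."""
    M1 = split_sampler(input_mask.shape)  # fresh split mask with the same shape as input_mask
    return {"mask": input_mask * M1}      # only pixels where input_mask == 1 can survive

def step_sketch(input_mask, split_sampler):
    """Hypothetical sketch: apply batch_step_sketch to each batch element and stack."""
    masks = [batch_step_sketch(m, split_sampler)["mask"] for m in input_mask]
    return {"mask": torch.stack(masks)}

# Usage with Bernoulli masks (an assumption; any binary sampler works)
g = torch.Generator().manual_seed(0)
sampler = lambda shape: torch.bernoulli(torch.full(shape, 0.5), generator=g)
orig = torch.bernoulli(torch.full((2, 1, 8, 8), 0.25), generator=g)
out = step_sketch(orig, sampler)["mask"]
print(out.shape)
```

Sampling \(M_1\) independently for each batch element means that two measurements sharing the same acceleration mask still receive different splits.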