MultiplicativeSplittingMaskGenerator
- class deepinv.physics.generator.MultiplicativeSplittingMaskGenerator(img_size, split_generator, *args, **kwargs)
Bases: BernoulliSplittingMaskGenerator
Multiplicative splitting mask generator.
Randomly generates binary masks using the given split_generator and multiplies them by the input_mask (i.e. the mask used to create the accelerated measurements).
Given an acceleration mask \(M\) sampled from a known distribution, this generator provides masks \(M' = M_1 \circ M\), where \(\circ\) denotes the elementwise product and \(M_1\) is sampled from split_generator, which typically follows the same distribution as \(M\).
See also
deepinv.loss.mri.WeightedSplittingLoss: K-weighted splitting loss proposed in Millard and Chiew [1], where this splitting mask generator is used for self-supervised learning.
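The product \(M' = M_1 \circ M\) can be illustrated without deepinv. The following is a minimal NumPy sketch under stated assumptions, not deepinv's implementation: both \(M\) and \(M_1\) are drawn as i.i.d. Bernoulli masks (keep probabilities 0.25 and 0.5, chosen only for illustration) instead of with a structured generator, and the helper multiplicative_split is hypothetical. It shows that \(M'\) only keeps pixels already kept by \(M\).

```python
import numpy as np


def multiplicative_split(mask, keep_prob, rng):
    """Hypothetical helper: return M' = M1 * M with M1 ~ i.i.d. Bernoulli(keep_prob).

    The Bernoulli M1 is an illustrative stand-in for split_generator.
    """
    m1 = (rng.random(mask.shape) < keep_prob).astype(mask.dtype)
    return m1 * mask  # elementwise (Hadamard) product


rng = np.random.default_rng(0)
# Acceleration mask M keeping about 1/4 of the pixels (acceleration of about 4).
M = (rng.random((1, 128, 128)) < 0.25).astype(np.float32)
M_split = multiplicative_split(M, keep_prob=0.5, rng=rng)

print(np.all(M_split <= M))      # True: M' only keeps pixels already in M
print(M.mean(), M_split.mean())  # roughly 0.25 and 0.125
```

Because the two masks in this sketch are independent, the expected kept fraction of \(M'\) is the product of the two keep probabilities (0.25 × 0.5 = 0.125). That product rule relies on the independence assumed here; for structured masks it need not hold exactly.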
- Examples:
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> physics_generator = GaussianMaskGenerator((1, 128, 128), acceleration=4)
>>> orig_mask = physics_generator.step(batch_size=2)["mask"]
>>> split_generator = GaussianMaskGenerator((1, 128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> mask_generator.step(batch_size=2, input_mask=orig_mask)["mask"].shape
torch.Size([2, 1, 128, 128])
- Parameters:
img_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, e.g. of shape (C, H, W) or (C, T, H, W). Note that this can be overridden on the fly by passing the img_size or input_mask arguments to step.
split_generator (deepinv.physics.generator.BaseMaskGenerator) – mask generator used for multiplicative splitting.
device (str, torch.device) – device where the tensor is stored (default: ‘cpu’).
rng (torch.Generator) – torch random number generator.
dtype (torch.dtype) – the data type of the generated parameters.
- References:
- batch_step(input_mask=None, img_size=None)
Create a single splitting mask (one element of a batch).
- Parameters:
input_mask (torch.Tensor, None) – optional mask to be split. If None, all pixels are considered. If not None, only pixels where input_mask == 1 are considered. The batch dimension should not be included in the shape.
img_size (tuple) – if not None, generate masks of this 2D image shape and override the img_size attribute; must be of the form (H, W).
- Returns:
mask without batch dimension, of shape specified either by img_size, input_mask, or the class attribute img_size.
- Return type:
torch.Tensor
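To make the input_mask behaviour of batch_step concrete, here is a hypothetical NumPy sketch; batch_step_sketch is not part of deepinv. It assumes an i.i.d. Bernoulli stand-in for split_generator and takes the output shape from input_mask when given, otherwise from img_size, otherwise from a default. That precedence is an assumption of the sketch, not documented deepinv behaviour. With input_mask=None every pixel is a candidate; otherwise only pixels where input_mask == 1 can be kept.

```python
import numpy as np


def batch_step_sketch(input_mask=None, img_size=None, default_size=(1, 16, 16),
                      keep_prob=0.5, rng=None):
    """Hypothetical single-sample split (no batch dimension), NOT deepinv's code."""
    rng = np.random.default_rng() if rng is None else rng
    if input_mask is None:
        # Assumed precedence for this sketch: img_size overrides the default shape.
        shape = tuple(img_size) if img_size is not None else tuple(default_size)
        # Nothing to split: treat every pixel as sampled (M = all ones).
        input_mask = np.ones(shape, dtype=np.float32)
    # Stand-in for split_generator: i.i.d. Bernoulli(keep_prob) mask M1.
    m1 = (rng.random(input_mask.shape) < keep_prob).astype(input_mask.dtype)
    # Elementwise product: only pixels where input_mask == 1 can survive.
    return m1 * input_mask


rng = np.random.default_rng(1)
full = batch_step_sketch(img_size=(8, 8), rng=rng)  # shape (8, 8), any pixel may be kept

M = np.zeros((1, 8, 8), dtype=np.float32)
M[..., ::2] = 1.0  # acceleration mask keeping every other column
split = batch_step_sketch(input_mask=M, rng=rng)  # shape (1, 8, 8), subset of M
```

In this sketch the odd columns of split are always zero, because input_mask excludes them before the split is applied.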