MultiplicativeSplittingMaskGenerator
- class deepinv.physics.generator.MultiplicativeSplittingMaskGenerator(img_size, split_generator, *args, **kwargs)
Bases: BernoulliSplittingMaskGenerator
Multiplicative splitting mask generator.
Randomly generates binary masks using the given split_generator and multiplies them by the input_mask (i.e. the mask that is used to create the accelerated measurements). Given an acceleration mask \(M\) sampled from a known distribution, this generator provides masks \(M'=M_1 \circ M\), where \(\circ\) denotes the element-wise (Hadamard) product and \(M_1\) is sampled from split_generator, typically following the same distribution as \(M\).
See also
deepinv.loss.mri.WeightedSplittingLoss
K-weighted splitting loss proposed in Millard and Chiew [1], where this splitting mask generator is used for self-supervised learning.
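The multiplicative construction can be illustrated with a minimal sketch, assuming small hand-written 2D binary masks stored as nested Python lists. deepinv itself operates on torch tensors, and \(M_1\) would be sampled by split_generator rather than fixed by hand; the helper name `hadamard` is hypothetical. The sketch shows the key property of \(M'=M_1 \circ M\): a location can be kept in \(M'\) only if it was measured in \(M\).

```python
# Illustrative sketch of M' = M_1 ∘ M (element-wise product of binary masks).
# Masks are plain nested lists of 0/1 here; deepinv works with torch tensors.


def hadamard(m1, m):
    """Hypothetical helper: element-wise product of two equally shaped 2D masks."""
    return [[a * b for a, b in zip(row1, row)] for row1, row in zip(m1, m)]


# Acceleration mask M: the locations that were actually measured.
M = [
    [1, 0, 1, 0],
    [1, 1, 1, 1],
    [0, 1, 0, 1],
]
# Split mask M_1 (fixed here for reproducibility; normally drawn by split_generator).
M1 = [
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 1, 1, 1],
]

M_prime = hadamard(M1, M)
print(M_prime)  # [[1, 0, 0, 0], [0, 1, 1, 0], [0, 1, 0, 1]]

# M' can only keep locations that were measured in M, i.e. M' <= M element-wise.
assert all(p <= m for rp, rm in zip(M_prime, M) for p, m in zip(rp, rm))
```

Because the product is element-wise, the same argument holds for any mask shape, including the (C, H, W) or (C, T, H, W) shapes used by the generator.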
- Examples:
>>> from deepinv.physics.generator import GaussianMaskGenerator, MultiplicativeSplittingMaskGenerator
>>> physics_generator = GaussianMaskGenerator((1, 128, 128), acceleration=4)
>>> orig_mask = physics_generator.step(batch_size=2)["mask"]
>>> split_generator = GaussianMaskGenerator((1, 128, 128), acceleration=2)
>>> mask_generator = MultiplicativeSplittingMaskGenerator((1, 128, 128), split_generator)
>>> mask_generator.step(batch_size=2, input_mask=orig_mask)["mask"].shape
torch.Size([2, 1, 128, 128])
- Parameters:
img_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, e.g. of shape (C, H, W) or (C, T, H, W). Note that this can be overridden on the fly by passing the img_size or input_mask arguments to step.
split_generator (deepinv.physics.generator.BaseMaskGenerator) – mask generator used for multiplicative splitting, i.e. to sample \(M_1\).
device (str, torch.device) – device where the tensor is stored (default: ‘cpu’).
rng (torch.Generator) – torch random number generator.
dtype (torch.dtype) – the data type of the generated parameters.
- References:
- batch_step(input_mask=None, img_size=None)
Create one batch of splitting masks.
- Parameters:
input_mask (torch.Tensor, None) – optional mask to be split. If None, all pixels are considered. If not None, only pixels where input_mask == 1 are considered. The batch dimension should not be included in its shape.
img_size (tuple) – if not None, generate masks of this 2D image shape and override the img_size attribute; must be of the form (H, W).
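The input_mask and img_size semantics of batch_step can be sketched as follows. This is a hypothetical pure-Python analogue, not the deepinv implementation: `batch_step_sketch` and its `keep_prob` argument are illustrative assumptions, with independent per-pixel Bernoulli sampling standing in for split_generator. It also assumes that input_mask determines the output shape when given, and that img_size is only used when input_mask is None.

```python
import random


def batch_step_sketch(keep_prob, input_mask=None, img_size=None, rng=None):
    """Hypothetical analogue of batch_step for a single sample (no batch dim).

    Assumption: each pixel of M_1 is kept independently with probability
    keep_prob, standing in for deepinv's split_generator.
    """
    rng = rng if rng is not None else random.Random()
    if input_mask is None:
        # No mask to split: every pixel of an (H, W) grid is considered.
        if img_size is None:
            raise ValueError("img_size is required when input_mask is None")
        height, width = img_size
        input_mask = [[1] * width for _ in range(height)]
    # Sample M_1 with the same shape as the mask being split.
    m1 = [[1 if rng.random() < keep_prob else 0 for _ in row] for row in input_mask]
    # M' = M_1 * input_mask: only pixels where input_mask == 1 can survive.
    return [[a * b for a, b in zip(r1, r)] for r1, r in zip(m1, input_mask)]


# Usage: split an existing acceleration mask M.
M = [[1, 0, 1], [0, 1, 1]]
split = batch_step_sketch(0.5, input_mask=M, rng=random.Random(0))
# Unmeasured locations (M == 0) remain zero after splitting.
assert all(s <= m for rs, rm in zip(split, M) for s, m in zip(rs, rm))
```

With keep_prob=1.0 the sketch returns input_mask unchanged, and with input_mask=None it reduces to a plain Bernoulli mask over the whole (H, W) grid.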
- Returns:
mask without batch dimension, of shape specified either by img_size, input_mask, or the class attribute img_size.
- Return type: