BernoulliSplittingMaskGenerator
- class deepinv.physics.generator.BernoulliSplittingMaskGenerator(tensor_size: Tuple[int], split_ratio: float, pixelwise: bool = True, device: device = device(type='cpu'), dtype: dtype = torch.float32, rng: Generator | None = None, *args, **kwargs)
Bases: PhysicsGenerator
Base generator for splitting/inpainting masks.
Generates binary masks in which each entry is drawn from a Bernoulli distribution, so the fraction of kept entries is approximately the given split ratio. Can be used either for generating random inpainting masks for deepinv.physics.Inpainting, or random splitting masks for deepinv.loss.SplittingLoss.
Optionally, pass in input_mask to subsample this mask according to the split ratio. For the mask ratio to be almost exactly as specified, use this option with a flat mask of ones as input.
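The difference between the "approximate" ratio of Bernoulli sampling and the "almost exact" ratio obtained by subsampling a mask of ones can be illustrated in plain PyTorch. This is a minimal sketch under our own assumptions, not the library's implementation: the helpers `bernoulli_mask` and `exact_mask` are hypothetical, and the exact variant assumes the subsampling keeps `int(split_ratio * n)` of the `n` ones, chosen uniformly without replacement.

```python
import torch

def bernoulli_mask(tensor_size, split_ratio, generator=None):
    # Hypothetical helper: keep each entry independently with probability
    # split_ratio, so the kept fraction only approximates split_ratio.
    u = torch.rand(tensor_size, generator=generator)
    return (u < split_ratio).float()

def exact_mask(input_mask, split_ratio, generator=None):
    # Hypothetical helper (assumed behaviour): keep exactly
    # int(split_ratio * n) of the n entries where input_mask == 1.
    flat = input_mask.flatten()
    idx = (flat == 1).nonzero(as_tuple=True)[0]  # positions of the ones
    perm = torch.randperm(idx.numel(), generator=generator)
    keep = idx[perm[: int(split_ratio * idx.numel())]]
    out = torch.zeros_like(flat)
    out[keep] = 1
    return out.reshape(input_mask.shape)

g = torch.Generator().manual_seed(0)
approx = bernoulli_mask((1, 32, 32), 0.6, generator=g)
exact = exact_mask(torch.ones(1, 32, 32), 0.6, generator=g)
print(approx.mean().item())  # close to, but generally not exactly, 0.6
print(exact.sum().item())    # int(0.6 * 1024) = 614 entries kept
```

With 1024 entries the Bernoulli fraction typically lands within a few percent of 0.6, while the subsampled mask always keeps the same count.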
- Examples:
Generate a random mask:
>>> from deepinv.physics.generator import BernoulliSplittingMaskGenerator
>>> gen = BernoulliSplittingMaskGenerator((1, 3, 3), split_ratio=0.6)
>>> gen.step(batch_size=2)["mask"].shape
torch.Size([2, 1, 3, 3])
Generate a splitting mask from a given input_mask:
>>> from deepinv.physics.generator import BernoulliSplittingMaskGenerator
>>> from deepinv.physics import Inpainting
>>> physics = Inpainting((1, 3, 3), 0.9)
>>> gen = BernoulliSplittingMaskGenerator((1, 3, 3), split_ratio=0.6)
>>> gen.step(batch_size=2, input_mask=physics.mask)["mask"].shape
torch.Size([2, 1, 3, 3])
- Parameters:
tensor_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, e.g. of shape (C, H, W), (C, M) or (M,).
split_ratio (float) – ratio of values to be kept.
pixelwise (bool) – Apply the mask in a pixelwise fashion, i.e., zero all channels in a given pixel simultaneously.
device (str, torch.device) – device where the tensor is stored (default: ‘cpu’).
rng (torch.Generator) – torch random number generator.
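The pixelwise option can be pictured as drawing one Bernoulli value per pixel location and sharing it across all channels, versus drawing every (channel, pixel) entry independently. The sketch below assumes the channel dimension comes first in tensor_size, as in (C, H, W); the helper `make_mask` is hypothetical and only illustrates the idea.

```python
import torch

def make_mask(tensor_size, split_ratio, pixelwise=True, generator=None):
    # Hypothetical helper, assuming tensor_size = (C, ...) with channels first.
    channels, *spatial = tensor_size
    if pixelwise:
        # One draw per pixel location, then the same value is copied to every
        # channel, so a dropped pixel is zero in all channels simultaneously.
        per_pixel = (torch.rand(spatial, generator=generator) < split_ratio).float()
        return per_pixel.expand(channels, *spatial).clone()
    # Independent draw for every (channel, pixel) entry.
    return (torch.rand(tensor_size, generator=generator) < split_ratio).float()

g = torch.Generator().manual_seed(0)
pix = make_mask((3, 8, 8), 0.5, pixelwise=True, generator=g)
ent = make_mask((3, 8, 8), 0.5, pixelwise=False, generator=g)
print(torch.equal(pix[0], pix[1]), torch.equal(pix[1], pix[2]))  # True True
```

The `.clone()` turns the broadcast view into a real tensor so that later in-place edits to one channel do not silently affect the others.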
- batch_step(input_mask: Tensor | None = None) → dict
Create one batch of splitting masks.
- Parameters:
input_mask (torch.Tensor, None) – optional mask to be split. If None, all pixels are considered. If not None, only pixels where mask==1 are considered. Batch dimension should not be included in shape.
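When an input_mask is given, only its ones are candidates, so the split produces two disjoint subsets of the original mask: the entries kept by the new mask and the remaining held-out entries. Conceptually, this is the kind of partition that measurement-splitting losses use. The sketch below works on an unbatched mask, as the parameter description requires; `split_input_mask` is a hypothetical helper, and the exact-count rule `int(split_ratio * n)` is our assumption rather than documented behaviour.

```python
import torch

def split_input_mask(input_mask, split_ratio, generator=None):
    # Hypothetical helper: split an unbatched input_mask into two disjoint parts.
    flat = input_mask.flatten()
    ones = (flat == 1).nonzero(as_tuple=True)[0]  # only entries where mask == 1
    perm = torch.randperm(ones.numel(), generator=generator)
    n_keep = int(split_ratio * ones.numel())      # assumed exact-count rule
    kept = torch.zeros_like(flat)
    kept[ones[perm[:n_keep]]] = 1
    kept = kept.reshape(input_mask.shape)
    held_out = input_mask - kept  # the ones of input_mask that were not kept
    return kept, held_out

g = torch.Generator().manual_seed(0)
acq = (torch.rand((1, 16, 16), generator=g) < 0.9).float()  # e.g. an acquisition mask
kept, held_out = split_input_mask(acq, 0.6, generator=g)
```

Because `held_out` is computed as a difference, `kept + held_out` reproduces the input mask exactly and the two parts never overlap.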
- check_pixelwise(input_mask=None) → bool
Check whether pixelwise masking can be used, given the dimensions of input_mask and tensor_size.
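The exact rule applied by check_pixelwise is not spelled out here. As an illustration only, one plausible compatibility check is sketched below under two assumptions of our own: pixelwise masking needs a leading channel dimension plus at least one further dimension, and an unbatched input_mask must be identical across channels for a shared per-pixel split to make sense. The function `can_use_pixelwise` is hypothetical and is not the library's logic.

```python
import torch

def can_use_pixelwise(tensor_size, input_mask=None):
    # Hypothetical rule (assumption): need a channel dimension plus at least
    # one more dimension to index pixels, e.g. (C, H, W) or (C, M), not (M,).
    if len(tensor_size) < 2:
        return False
    if input_mask is None:
        return True
    # Assumption: unbatched input_mask of shape (C, ...); a pixelwise split
    # only makes sense if every channel carries the same mask.
    return bool((input_mask == input_mask[:1]).all())

print(can_use_pixelwise((3, 8, 8)))  # True
print(can_use_pixelwise((64,)))      # False
```

Returning a plain `bool` mirrors the documented return type, which makes the result easy to use as a flag when choosing between the two sampling modes.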
- step(batch_size=1, input_mask: Tensor | None = None, seed: int | None = None, **kwargs) → dict
Generate a random mask.
If input_mask is None, generates a standard random mask that can be used for deepinv.physics.Inpainting. If input_mask is specified, splits the input mask into subsets given the split ratio.
- Parameters:
batch_size (int) – batch size. If None, no batch dimension is created. If input_mask is passed and has its own batch dimension > 1, batch_size is ignored.
input_mask (torch.Tensor, None) – optional mask to be split. If None, all pixels are considered. If not None, only pixels where mask==1 are considered. input_mask shape can optionally include a batch dimension.
seed (int) – the seed for the random number generator.
- Returns:
dictionary with key ‘mask’: tensor of size (batch_size, *tensor_size) with values in {0, 1}.
- Return type:
dict
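The batching and seeding rules of step can be mimicked in a few lines: seed a private generator, draw one mask per batch element, and, when input_mask already carries a batch dimension larger than 1, produce one split per provided mask and ignore batch_size. This is a hypothetical stand-in (`step_sketch` and `_subsample` are our names), it assumes the exact-count subsampling rule used above, and it omits the batch_size=None case for brevity.

```python
import torch

def _subsample(base, split_ratio, g):
    # Assumed rule: keep a random int(split_ratio * n) of the n ones in base.
    flat = base.flatten()
    idx = (flat == 1).nonzero(as_tuple=True)[0]
    perm = torch.randperm(idx.numel(), generator=g)
    out = torch.zeros_like(flat)
    out[idx[perm[: int(split_ratio * idx.numel())]]] = 1
    return out.reshape(base.shape)

def step_sketch(tensor_size, split_ratio, batch_size=1, input_mask=None, seed=None):
    # Hypothetical stand-in for step(); not the library implementation.
    g = torch.Generator()
    if seed is None:
        g.seed()           # non-deterministic seed
    else:
        g.manual_seed(seed)  # same seed -> same masks
    if input_mask is None:
        # Bernoulli draw per sample: kept fraction only approximately split_ratio.
        masks = [(torch.rand(tensor_size, generator=g) < split_ratio).float()
                 for _ in range(batch_size)]
    else:
        if input_mask.dim() == len(tensor_size):
            input_mask = input_mask.unsqueeze(0)  # add a batch dimension of 1
        if input_mask.shape[0] > 1:
            bases = list(input_mask)  # own batch dim > 1: batch_size is ignored
        else:
            bases = [input_mask[0]] * batch_size
        masks = [_subsample(b, split_ratio, g) for b in bases]
    return {"mask": torch.stack(masks)}

a = step_sketch((1, 3, 3), 0.6, batch_size=2, seed=0)["mask"]
b = step_sketch((1, 3, 3), 0.6, batch_size=2, seed=0)["mask"]
print(a.shape, torch.equal(a, b))  # torch.Size([2, 1, 3, 3]) True
```

Using a dedicated `torch.Generator` rather than the global RNG keeps mask generation reproducible without disturbing other random draws in a training loop.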
Examples using BernoulliSplittingMaskGenerator:
- Self-supervised learning with measurement splitting
- Self-supervised MRI reconstruction with Artifact2Artifact