GeneratorMixture
- class deepinv.physics.generator.GeneratorMixture(generators, probs, use_batch_sampling=True, device='cpu', rng=None, verbose=False)
Bases: PhysicsGenerator

Base class for mixing multiple physics generators. The mixture randomly selects a subset of batch elements to be generated by each generator, according to the probabilities given in the constructor.
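The selection step can be pictured with a short sketch. It is written in plain Python and uses `random.Random` in place of `torch.Generator`. The helpers `assign_generators` and `mix` and the two toy generators are hypothetical stand-ins that only illustrate the idea; this is not deepinv's implementation. Each batch element is assigned a generator index drawn according to `probs`. Each generator is then called once on the size of its own subset, and the results are scattered back into batch order.

```python
import random


def assign_generators(batch_size, probs, rng):
    """Draw one generator index per batch element, weighted by `probs`.

    Hypothetical helper for illustration only.
    """
    return rng.choices(range(len(probs)), weights=probs, k=batch_size)


def mix(generators, probs, batch_size, rng):
    """Build a batch in which each element comes from a randomly chosen generator."""
    choice = assign_generators(batch_size, probs, rng)
    out = [None] * batch_size
    for g_idx, gen in enumerate(generators):
        # Batch positions assigned to this generator
        idx = [i for i, c in enumerate(choice) if c == g_idx]
        if not idx:
            continue
        # Call the generator once, for its own subset of the batch
        for i, value in zip(idx, gen(len(idx))):
            out[i] = value
    return out, choice


# Toy generators returning a label per sample (stand-ins for real PSF generators)
def motion(n):
    return ["motion"] * n


def diffraction(n):
    return ["diffraction"] * n


params, choice = mix([motion, diffraction], [0.5, 0.5], batch_size=8, rng=random.Random(0))
print(choice)
print(params)
```

Seeding the `random.Random` instance makes the assignment reproducible, which loosely mirrors the role of the `rng` argument.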
- Parameters:
  - generators (list[PhysicsGenerator]) – the generators, each an instance of deepinv.physics.generator.PhysicsGenerator.
  - probs (list[float]) – the probability of each generator being used at each step.
  - device (str) – device on which the generator is located. Defaults to "cpu".
  - rng (torch.Generator) – a pseudorandom number generator for the parameter generation. If None, a generator is created on the specified device with a random seed.
  - verbose (bool) – whether to print warnings about the batch compatibility of the generators. Defaults to False.
- Param:
  - use_batch_sampling (bool) – whether to sample a different generator for each element in the batch. This is only possible if all generators in the mixture produce parameters with the same keys and shapes; otherwise, a single generator is sampled per batch. Defaults to True.
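The effect of `use_batch_sampling` can be sketched in the same hedged way. Here each generator is described by a hypothetical `{key: shape}` dict. The shapes are illustrative and are not what the blur generators actually return. The helper names `batch_compatible` and `choose_generators` are also made up. When all specs match and per-element sampling is enabled, every element draws its own generator. Otherwise, one generator index is drawn and reused for the whole batch.

```python
import random


def batch_compatible(specs):
    """True if every generator yields the same parameter keys with the same shapes.

    `specs` is a hypothetical list of {key: shape} dicts, one per generator.
    """
    return all(spec == specs[0] for spec in specs)


def choose_generators(specs, probs, batch_size, use_batch_sampling, rng):
    """Return one generator index per batch element."""
    n = len(specs)
    if use_batch_sampling and batch_compatible(specs):
        # Per-element sampling: each element may come from a different generator
        return rng.choices(range(n), weights=probs, k=batch_size)
    # Fallback: a single generator is drawn and used for the whole batch
    (k,) = rng.choices(range(n), weights=probs, k=1)
    return [k] * batch_size


# Illustrative specs: identical keys and shapes vs. mismatched shapes
same = [{"filter": (1, 3, 3)}, {"filter": (1, 3, 3)}]
different = [{"filter": (1, 3, 3)}, {"filter": (1, 5, 5)}]

rng = random.Random(0)
print(choose_generators(same, [0.5, 0.5], 6, True, rng))       # mixed indices possible
print(choose_generators(different, [0.5, 0.5], 6, True, rng))  # one index repeated
```

Drawing a single index for incompatible generators avoids having to stack parameter tensors of different shapes into one batch.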
- Examples:
Mixing two types of blur:

>>> import torch
>>> from deepinv.physics.generator import MotionBlurGenerator, DiffractionBlurGenerator
>>> from deepinv.physics.generator import GeneratorMixture
>>> _ = torch.manual_seed(0)
>>> g1 = MotionBlurGenerator(psf_size=(3, 3), num_channels=1)
>>> g2 = DiffractionBlurGenerator(psf_size=(3, 3), num_channels=1)
>>> generator = GeneratorMixture([g1, g2], [0.5, 0.5])
>>> params_dict = generator.step(batch_size=1)
>>> print(params_dict.keys())
dict_keys(['filter'])