RandomLossScheduler
- class deepinv.loss.RandomLossScheduler(*loss: Loss, generator: Generator | None = None)
Bases: BaseLossScheduler
Schedule losses at random.
The scheduler wraps a list of losses. Each time the scheduler is called, one of the wrapped losses is selected at random and used for the forward pass.
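The selection mechanism described above can be sketched in plain Python. This is a minimal sketch, not the deepinv implementation: the class `RandomScheduler`, the toy losses `l1` and `l2`, and the use of Python's `random.Random` in place of a `torch.Generator` are all illustrative assumptions.

```python
import random
from typing import Callable, Optional

# Hypothetical stand-in for a loss: any callable taking keyword arguments
# and returning a float.
LossFn = Callable[..., float]


class RandomScheduler:
    """Sketch of random loss scheduling (hypothetical, not deepinv's code).

    On every call, exactly one of the wrapped losses is drawn uniformly at
    random and only that loss is evaluated.
    """

    def __init__(self, *losses: LossFn, rng: Optional[random.Random] = None):
        if not losses:
            raise ValueError("at least one loss is required")
        self.losses = losses
        # Private RNG (stand-in for torch.Generator) so choices can be reproduced.
        self.rng = rng if rng is not None else random.Random()
        self.last_choice: Optional[int] = None

    def __call__(self, **kwargs) -> float:
        # Draw an index in [0, len(losses)) and evaluate only that loss.
        self.last_choice = self.rng.randrange(len(self.losses))
        return self.losses[self.last_choice](**kwargs)


def l2(x, x_net):
    # Sum of squared differences.
    return sum((a - b) ** 2 for a, b in zip(x, x_net))


def l1(x, x_net):
    # Sum of absolute differences.
    return sum(abs(a - b) for a, b in zip(x, x_net))


sched = RandomScheduler(l2, l1, rng=random.Random(0))
x, x_net = [0.0, 1.0, 2.0], [0.0, 0.0, 0.0]
values = [sched(x=x, x_net=x_net) for _ in range(5)]
print(values)  # each entry is either 5.0 (l2) or 3.0 (l1)
```

Because only one loss is evaluated per call, the value returned on each call depends on which loss was drawn, while over many calls every wrapped loss contributes to training.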
- Example:
>>> import torch
>>> from deepinv.loss import RandomLossScheduler, SupLoss
>>> from deepinv.loss.metric import SSIM
>>> l = RandomLossScheduler(SupLoss(), SSIM(train_loss=True)) # Choose randomly between Sup and SSIM
>>> x_net = x = torch.tensor([0., 0., 0.])
>>> l(x=x, x_net=x_net)
tensor(0.)
- Parameters:
*loss (Loss) – one or more losses to be scheduled.
generator (Generator) – torch random number generator used to select the loss at each call; defaults to None.
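Passing a seeded `generator` is what makes the sequence of random choices reproducible. The sketch below assumes the choice is drawn with `torch.randint`, and `draw_choices` is a hypothetical helper rather than part of deepinv; it only shows that two identically seeded `torch.Generator` objects yield the same sequence of indices.

```python
import torch


def draw_choices(generator: torch.Generator, n_losses: int, n_calls: int) -> list:
    # Hypothetical helper (not deepinv): draw one loss index per call,
    # as a random loss scheduler might.
    return [
        int(torch.randint(n_losses, (1,), generator=generator).item())
        for _ in range(n_calls)
    ]


# Two generators with the same seed yield the same sequence of choices.
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
a = draw_choices(g1, n_losses=2, n_calls=10)
b = draw_choices(g2, n_losses=2, n_calls=10)
print(a == b)  # True

# Following the documented signature, a seeded generator would be passed as:
# RandomLossScheduler(SupLoss(), SSIM(train_loss=True), generator=torch.Generator().manual_seed(42))
```

Leaving `generator` as None gives a non-reproducible sequence of choices across runs.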