RandomLossScheduler

class deepinv.loss.RandomLossScheduler(*loss: Loss, generator: Generator | None = None)

Bases: BaseLossScheduler

Schedule losses at random.

The scheduler wraps one or more losses. Each time the scheduler is called, one of the wrapped losses is selected at random and used for that forward pass.
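To illustrate the selection mechanism described above, here is a toy, pure-Python sketch. It is not deepinv's implementation. `ToyRandomLossScheduler`, `mse` and `mae` are hypothetical names, plain callables stand in for deepinv `Loss` objects, and Python's `random.Random` stands in for a `torch.Generator`. On every call, one wrapped loss is drawn uniformly at random and evaluated.

```python
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for a deepinv Loss: any callable (x, x_net) -> float.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(x: Sequence[float], x_net: Sequence[float]) -> float:
    # Mean squared error between two equal-length sequences.
    return sum((a - b) ** 2 for a, b in zip(x, x_net)) / len(x)


def mae(x: Sequence[float], x_net: Sequence[float]) -> float:
    # Mean absolute error between two equal-length sequences.
    return sum(abs(a - b) for a, b in zip(x, x_net)) / len(x)


class ToyRandomLossScheduler:
    """Toy sketch: apply one wrapped loss, chosen uniformly at random, per call."""

    def __init__(self, *losses: LossFn, rng: Optional[random.Random] = None):
        if not losses:
            raise ValueError("at least one loss is required")
        self.losses: List[LossFn] = list(losses)
        # If no generator is given, fall back to an unseeded one.
        self.rng = rng if rng is not None else random.Random()
        self.last_choice: Optional[LossFn] = None

    def __call__(self, x: Sequence[float], x_net: Sequence[float]) -> float:
        # Draw exactly one loss for this call, then evaluate it.
        self.last_choice = self.rng.choice(self.losses)
        return self.last_choice(x, x_net)


sched = ToyRandomLossScheduler(mse, mae, rng=random.Random(0))
print(sched([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]))  # 0.0 whichever loss is drawn
```

As in the doctest that follows, identical inputs give a zero loss regardless of which loss is drawn. With differing inputs, the value depends on the random choice.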

Example:

>>> import torch
>>> from deepinv.loss import RandomLossScheduler, SupLoss
>>> from deepinv.loss.metric import SSIM
>>> l = RandomLossScheduler(SupLoss(), SSIM(train_loss=True)) # Choose randomly between Sup and SSIM
>>> x_net = x = torch.tensor([0., 0., 0.])
>>> l(x=x, x_net=x_net)
tensor(0.)

Parameters:
  • *loss (Loss) – loss or multiple losses to be scheduled.

  • generator (Generator) – torch.Generator used to draw the random selection, defaults to None
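The generator parameter is mainly useful for reproducibility: a generator seeded with the same value yields the same sequence of random choices. With deepinv, you would pass something like `torch.Generator().manual_seed(0)` as `generator`. The sketch below is a stand-in, not deepinv code. It uses Python's `random.Random` to show the same idea, and `pick_indices` is a hypothetical helper.

```python
import random
from typing import List


def pick_indices(n_losses: int, n_calls: int, seed: int) -> List[int]:
    # Draw one loss index per call from a generator seeded with `seed`.
    rng = random.Random(seed)
    return [rng.randrange(n_losses) for _ in range(n_calls)]


# Two runs with the same seed make identical choices.
run_a = pick_indices(n_losses=2, n_calls=10, seed=0)
run_b = pick_indices(n_losses=2, n_calls=10, seed=0)
print(run_a == run_b)  # True
```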

schedule(epoch) → List[Loss]

Return the losses selected by the schedule, optionally depending on the current epoch.

Parameters:

epoch (int) – current epoch number

Returns:

list[Loss] – the selected (sub)list of losses to use for this call.
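Because schedule returns a list rather than a single loss, a scheduler can select several losses at once and can base that choice on the epoch. The pure-Python sketch below illustrates that contract under two assumptions. First, the caller evaluates every returned loss and sums the results. Second, a hypothetical epoch-based rule (`warmup_schedule`) uses only the first loss during a warm-up period and all losses afterwards. None of these helper names come from deepinv. For RandomLossScheduler itself, the returned list would hold one randomly chosen loss, whatever the epoch.

```python
from typing import Callable, List

Loss = Callable[[float, float], float]


def squared_error(x: float, x_net: float) -> float:
    return (x - x_net) ** 2


def absolute_error(x: float, x_net: float) -> float:
    return abs(x - x_net)


def warmup_schedule(losses: List[Loss], epoch: int, warmup: int = 5) -> List[Loss]:
    # Hypothetical epoch-based rule: only the first loss during warm-up,
    # then the full list afterwards.
    return losses[:1] if epoch < warmup else list(losses)


def total_loss(losses: List[Loss], x: float, x_net: float, epoch: int) -> float:
    # Evaluate every loss returned by the schedule and sum the results.
    return sum(loss(x, x_net) for loss in warmup_schedule(losses, epoch))


losses = [squared_error, absolute_error]
print(total_loss(losses, 3.0, 1.0, epoch=0))   # 4.0 (squared error only)
print(total_loss(losses, 3.0, 1.0, epoch=10))  # 6.0 (4.0 + 2.0)
```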