InterleavedEpochLossScheduler

class deepinv.loss.InterleavedEpochLossScheduler(*loss: Loss, generator: Generator | None = None)

Bases: BaseLossScheduler

Schedule losses sequentially epoch-by-epoch.

The scheduler wraps a list of losses. At each epoch, the next loss in the list is selected and used to compute the loss for that epoch. After the last loss, the cycle restarts from the first.
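The selection rule can be sketched in plain Python. This is a minimal illustration, not the deepinv source. It assumes the loss for an epoch is chosen round-robin as `losses[epoch % len(losses)]`, which matches the alternating SupLoss/SSIM behaviour in the example. The helper name `interleaved_schedule` is hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def interleaved_schedule(losses: Sequence[T], epoch: int) -> List[T]:
    """Hypothetical helper: select one loss for this epoch, cycling in order."""
    if not losses:
        raise ValueError("at least one loss is required")
    # Assumed round-robin rule: epoch 0 -> losses[0], epoch 1 -> losses[1], ..., then wrap.
    return [losses[epoch % len(losses)]]


names = ["SupLoss", "SSIM"]
for epoch in range(4):
    print(epoch, interleaved_schedule(names, epoch))
# 0 ['SupLoss']
# 1 ['SSIM']
# 2 ['SupLoss']
# 3 ['SSIM']
```

With two losses, even epochs select the first loss and odd epochs select the second. With three or more losses, each one is used once every `len(losses)` epochs.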

Example:

>>> import torch
>>> from deepinv.loss import InterleavedEpochLossScheduler, SupLoss
>>> from deepinv.loss.metric import SSIM
>>> l = InterleavedEpochLossScheduler(SupLoss(), SSIM(train_loss=True))  # alternate between SupLoss and SSIM, one per epoch
>>> x_net = x = torch.tensor([0., 0., 0.])
>>> l(x=x, x_net=x_net, epoch=0)  # epoch 0 selects SupLoss; x_net equals x, so the loss is zero
tensor(0.)

Parameters:

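*loss (Loss) – one or more losses to schedule, selected in the order given.

generator (torch.Generator | None) – optional random number generator. Defaults to None.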

schedule(epoch) → List[Loss]

Return the losses selected for the given epoch. For this scheduler, the list contains a single loss, chosen by cycling through the wrapped losses in order.

Parameters:

epoch (int) – current epoch number

Returns:

list[Loss] – the selected (sub)list of losses to use for this epoch.
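The following self-contained sketch shows how the output of `schedule` could feed into the loss computation, using plain-Python stand-ins for the losses. The class `InterleavedSchedulerSketch` and the `mse`/`mae` helpers are hypothetical. Summing the returned (sub)list is an assumption about how a scheduler combines the selected losses, not a description of deepinv's implementation.

```python
from typing import Callable, List, Sequence

# Hypothetical loss signature: takes (x, x_net) and returns a scalar.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(x: Sequence[float], x_net: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, x_net)) / len(x)


def mae(x: Sequence[float], x_net: Sequence[float]) -> float:
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(a - b) for a, b in zip(x, x_net)) / len(x)


class InterleavedSchedulerSketch:
    """Hypothetical stand-in for an epoch-interleaved loss scheduler."""

    def __init__(self, *losses: LossFn) -> None:
        if not losses:
            raise ValueError("at least one loss is required")
        self.losses = list(losses)

    def schedule(self, epoch: int) -> List[LossFn]:
        # Assumed rule: one loss per epoch, cycling through the list in order.
        return [self.losses[epoch % len(self.losses)]]

    def __call__(self, x: Sequence[float], x_net: Sequence[float], epoch: int) -> float:
        # Assumption: the losses returned by schedule() are summed.
        return sum(loss(x, x_net) for loss in self.schedule(epoch))


sched = InterleavedSchedulerSketch(mse, mae)
x = [0.0, 1.0, 2.0]
x_net = [0.0, 0.0, 0.0]
for epoch in range(4):
    # Even epochs use mse (5/3 here), odd epochs use mae (1.0 here).
    print(epoch, sched.schedule(epoch)[0].__name__, sched(x, x_net, epoch))
```

Because only one loss is active per epoch, the value returned alternates between the two criteria even though the inputs are unchanged.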