InterleavedLossScheduler

class deepinv.loss.InterleavedLossScheduler(*loss: Loss)

Bases: BaseLossScheduler

Schedule losses sequentially one-by-one.

The scheduler wraps a list of losses. Each time the scheduler is called, it selects the next loss in order and uses it for the forward pass.

Example:

>>> import torch
>>> from deepinv.loss import InterleavedLossScheduler, SupLoss
>>> from deepinv.loss.metric import SSIM
>>> l = InterleavedLossScheduler(SupLoss(), SSIM(train_loss=True)) # Alternate between SupLoss and SSIM on successive calls
>>> x_net = x = torch.tensor([0., 0., 0.])
>>> l(x=x, x_net=x_net)
tensor(0.)

Parameters:

*loss (Loss) – loss or multiple losses to be scheduled.

schedule(epoch) → List[Loss]

Return the losses selected by the defined schedule, optionally depending on the current epoch.

Parameters:

epoch (int) – current epoch number

Returns:

list[Loss] – selected (sub)list of losses to be used this time.
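
The one-by-one selection described above can be illustrated with a minimal, dependency-free sketch. This is not deepinv's implementation: RoundRobinScheduler, squared_error and absolute_error are hypothetical names introduced only for illustration, losses are modelled as plain callables on floats, and wrapping back to the first loss after the last one is an assumption about the intended behaviour.

```python
from typing import Callable, List

# Hypothetical stand-in for a loss: any callable (x, x_net) -> float.
LossFn = Callable[[float, float], float]


class RoundRobinScheduler:
    """Illustrative sketch only, not deepinv's InterleavedLossScheduler."""

    def __init__(self, *losses: LossFn):
        if not losses:
            raise ValueError("at least one loss is required")
        self.losses = list(losses)
        self._next = 0  # index of the loss to select on the next call

    def schedule(self, epoch: int) -> List[LossFn]:
        # epoch is accepted to mirror the documented signature but is unused here.
        selected = [self.losses[self._next]]
        # Assumption: wrap around to the first loss after the last one.
        self._next = (self._next + 1) % len(self.losses)
        return selected

    def __call__(self, x: float, x_net: float, epoch: int = 0) -> float:
        # Evaluate only the loss(es) selected for this call.
        return sum(loss(x, x_net) for loss in self.schedule(epoch))


def squared_error(x: float, x_net: float) -> float:
    return (x_net - x) ** 2


def absolute_error(x: float, x_net: float) -> float:
    return abs(x_net - x)


scheduler = RoundRobinScheduler(squared_error, absolute_error)
# Four consecutive calls alternate between the two losses.
values = [scheduler(x=0.0, x_net=3.0) for _ in range(4)]
```

In this sketch schedule returns a single-element list, matching the documented return type of a (sub)list of losses; a scheduler that activates several losses at once could return a longer list without changing the calling code.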