StepLossScheduler

class deepinv.loss.StepLossScheduler(*loss: Loss, epoch_thresh: int = 0)

Bases: BaseLossScheduler

Activate losses after a specified epoch.

The scheduler wraps one or more losses. While the current epoch is less than or equal to epoch_thresh, it returns 0. Once the epoch exceeds epoch_thresh, it returns the sum of the wrapped losses.
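The rule above can be sketched in plain Python. This is a minimal illustration, not deepinv's implementation: losses are modelled as ordinary callables returning floats, and the class name StepSumSketch and the helper losses abs_err and sq_err are hypothetical.

```python
from typing import Callable

# Assumption for illustration: a "loss" is any callable mapping (x_net, x) to a float.
LossFn = Callable[[float, float], float]


class StepSumSketch:
    """Hypothetical stand-in mimicking the documented behaviour of StepLossScheduler."""

    def __init__(self, *losses: LossFn, epoch_thresh: int = 0):
        self.losses = list(losses)
        self.epoch_thresh = epoch_thresh

    def __call__(self, x_net: float, x: float, epoch: int) -> float:
        # While epoch <= epoch_thresh the losses are inactive and the total is 0.
        if epoch <= self.epoch_thresh:
            return 0.0
        # Afterwards, the total is the sum of all wrapped losses.
        return sum(loss(x_net, x) for loss in self.losses)


def abs_err(x_net: float, x: float) -> float:
    return abs(x_net - x)


def sq_err(x_net: float, x: float) -> float:
    return (x_net - x) ** 2


sched = StepSumSketch(abs_err, sq_err, epoch_thresh=10)
print(sched(0.0, 2.0, epoch=10))  # 0.0 (epoch equals the threshold: still inactive)
print(sched(0.0, 2.0, epoch=11))  # 6.0 (|0 - 2| + (0 - 2)^2 = 2 + 4)
```

Note that the comparison is strict: at epoch == epoch_thresh the losses are still switched off.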

Example:

>>> import torch
>>> from deepinv.loss import StepLossScheduler
>>> from deepinv.loss.metric import SSIM
>>> l = StepLossScheduler(SSIM(train_loss=True), epoch_thresh=10) # Use SSIM only after epoch 10
>>> x_net = torch.zeros(1, 1, 12, 12)
>>> x = torch.ones(1, 1, 12, 12)
>>> l(x=x, x_net=x_net, epoch=0)
tensor(0., requires_grad=True)
>>> l(x=x, x_net=x_net, epoch=11)
tensor([0.9999])

Parameters:
  • *loss (Loss) – one or more losses to be scheduled.

  • epoch_thresh (int) – epoch threshold; the losses are used only when the current epoch is strictly greater than this value. Default: 0.

schedule(epoch) → List[Loss]

Return the losses selected by the schedule, which may depend on the current epoch.

Parameters:

epoch (int) – current epoch number

Return list[Loss]:

selected (sub)list of losses to be used this time.
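For this class, the selection rule can be sketched as a standalone function. It is a hypothetical helper (step_schedule is not part of deepinv), assuming the wrapped losses are returned unchanged once the epoch exceeds epoch_thresh and an empty list is returned before that, which is consistent with the total of 0 described for the class.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def step_schedule(losses: Sequence[T], epoch: int, epoch_thresh: int = 0) -> List[T]:
    """Hypothetical sketch: select all losses strictly after epoch_thresh, none before."""
    if epoch > epoch_thresh:
        return list(losses)
    # Empty selection: summing over it yields a total loss of 0.
    return []


names = ["mse", "ssim"]  # placeholder strings standing in for Loss objects
for epoch in (0, 10, 11):
    print(epoch, step_schedule(names, epoch, epoch_thresh=10))
# 0 []
# 10 []
# 11 ['mse', 'ssim']
```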