StepLossScheduler
- class deepinv.loss.StepLossScheduler(*loss: Loss, epoch_thresh: int = 0)
Bases: BaseLossScheduler
Activate losses at a specified epoch.
The scheduler wraps a list of losses. While the epoch is less than or equal to the threshold `epoch_thresh`, it returns 0. Once the epoch exceeds the threshold, it returns the sum of the wrapped losses.
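The step rule described above can be illustrated with a minimal, dependency-free sketch. It is not the deepinv implementation: the class `StepScheduleSketch` and the toy losses `l1` and `l2` are hypothetical, the losses are plain callables operating on Python lists rather than deepinv `Loss` modules on tensors, and the inactive branch returns the float `0.0` instead of a tensor.

```python
from typing import Callable, Sequence


class StepScheduleSketch:
    """Hypothetical re-implementation of the step-activation rule (not deepinv code)."""

    def __init__(self, *losses: Callable[..., float], epoch_thresh: int = 0):
        self.losses: Sequence[Callable[..., float]] = losses
        self.epoch_thresh = epoch_thresh

    def __call__(self, epoch: int, **kwargs) -> float:
        # Inactive phase: while epoch <= threshold, the losses contribute nothing.
        if epoch <= self.epoch_thresh:
            return 0.0
        # Active phase: sum of all wrapped losses evaluated on the same inputs.
        return sum(loss(**kwargs) for loss in self.losses)


def l1(x, x_net):
    # Toy loss: mean absolute error between two equal-length lists.
    return sum(abs(a - b) for a, b in zip(x, x_net)) / len(x)


def l2(x, x_net):
    # Toy loss: mean squared error between two equal-length lists.
    return sum((a - b) ** 2 for a, b in zip(x, x_net)) / len(x)


sched = StepScheduleSketch(l1, l2, epoch_thresh=10)
x, x_net = [1.0, 1.0], [0.0, 0.0]
print(sched(epoch=10, x=x, x_net=x_net))  # 0.0 (epoch 10 is still <= threshold)
print(sched(epoch=11, x=x, x_net=x_net))  # 2.0 (l1 = 1.0 plus l2 = 1.0)
```

Note that the comparison is inclusive: with `epoch_thresh=10`, the losses first contribute at epoch 11.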
- Example:
>>> import torch
>>> from deepinv.loss import StepLossScheduler
>>> from deepinv.loss.metric import SSIM
>>> l = StepLossScheduler(SSIM(train_loss=True), epoch_thresh=10) # Use SSIM only after epoch 10
>>> x_net = torch.zeros(1, 1, 12, 12)
>>> x = torch.ones(1, 1, 12, 12)
>>> l(x=x, x_net=x_net, epoch=0)
tensor(0., requires_grad=True)
>>> l(x=x, x_net=x_net, epoch=11)
tensor([0.9999])
- Parameters: