BaseLossScheduler

class deepinv.loss.BaseLossScheduler(*loss: Loss, generator: Generator | None = None)

Bases: Loss

Base class for loss schedulers.

Wraps a list of losses. Each time forward is called, a subset of them is selected according to the defined schedule, and only the selected losses are applied to the inputs.

Parameters:
  • *loss (Loss) – one or more losses to be scheduled.

  • generator (Generator) – torch random number generator (torch.Generator). Defaults to None.
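The wrapping and selection contract described above can be sketched in plain Python. This is a hypothetical stand-in, not deepinv code: SketchLossScheduler is an illustrative name, losses are ordinary callables returning floats, random.Random stands in for torch.Generator, and it is assumed that the default schedule selects every wrapped loss and that forward sums the selected ones.

```python
import random
from typing import Callable, List, Optional

# Hypothetical stand-in for a loss: any callable taking keyword inputs and returning a float.
LossFn = Callable[..., float]


class SketchLossScheduler:
    """Plain-Python sketch of the loss-scheduler pattern (not deepinv code)."""

    def __init__(self, *loss: LossFn, generator: Optional[random.Random] = None):
        # Store the wrapped losses; random.Random stands in for torch.Generator.
        self.losses: List[LossFn] = list(loss)
        self.rng = generator if generator is not None else random.Random()

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        # Assumed default: no scheduling, every wrapped loss is selected.
        return self.losses

    def forward(self, epoch: Optional[int] = None, **kwargs) -> float:
        # Select the losses for this pass, apply each to the same inputs, and sum (assumed).
        return sum(l(epoch=epoch, **kwargs) for l in self.schedule(epoch))


def sq_err(x_net, x, **kwargs):
    # Hypothetical supervised loss on scalars.
    return (x_net - x) ** 2


def abs_err(x_net, x, **kwargs):
    # Hypothetical second loss on scalars.
    return abs(x_net - x)


sched = SketchLossScheduler(sq_err, abs_err)
total = sched.forward(x_net=3.0, x=1.0, epoch=0)
```

With both losses selected, total is the sum of the two terms, 4.0 + 2.0 = 6.0.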

adapt_model(model: Module, **kwargs)

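Adapt the model using all wrapped losses.

Some loss functions require the model's forward call to be adapted before the forward pass. The adaptation of every wrapped loss is applied, not only that of the losses selected by the schedule.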

Parameters:
  • model (Module) – reconstruction model
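adapt_model visits every wrapped loss, independently of the schedule. A minimal sketch of that chaining, assuming each loss exposes an adapt_model method that returns a (possibly wrapped) model; ScaleOutputLoss, PlainLoss and adapt_model_all are hypothetical names used only for illustration, and the "model" is a plain scalar function.

```python
from typing import Callable, List

# Hypothetical stand-in for a reconstruction model: a scalar function.
Model = Callable[[float], float]


class ScaleOutputLoss:
    """Hypothetical loss whose adapt_model wraps the model's forward call."""

    def __init__(self, factor: float):
        self.factor = factor

    def adapt_model(self, model: Model, **kwargs) -> Model:
        factor = self.factor

        def scaled(y: float) -> float:
            # Call the original model, then rescale its output.
            return factor * model(y)

        return scaled


class PlainLoss:
    """Hypothetical loss that needs no adaptation."""

    def adapt_model(self, model: Model, **kwargs) -> Model:
        return model


def adapt_model_all(losses: List, model: Model, **kwargs) -> Model:
    # Pass the model through every wrapped loss in order, regardless of any schedule.
    for l in losses:
        model = l.adapt_model(model, **kwargs)
    return model


def base_model(y: float) -> float:
    # Hypothetical reconstruction model on scalars.
    return y + 1.0


adapted = adapt_model_all([ScaleOutputLoss(2.0), PlainLoss(), ScaleOutputLoss(3.0)], base_model)
```

Each adaptation wraps the result of the previous one, so the order of the wrapped losses determines the nesting order; here adapted(1.0) evaluates 3 * (2 * (1 + 1)).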

forward(x_net: Tensor | None = None, x: Tensor | None = None, y: Tensor | None = None, physics: Physics | None = None, model: Module | None = None, epoch: int | None = None, **kwargs)

Loss forward pass.

When called, selects the losses to be used at this pass according to the defined schedule, and applies them to the inputs.

Parameters:
  • x_net (Tensor) – model output

  • x (Tensor) – ground truth

  • y (Tensor) – measurement

  • physics (Physics) – measurement operator

  • model (Module) – reconstruction model

  • epoch (int) – current epoch
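Since epoch is handed to the schedule, the set of active losses can change over training. The sketch below is a hypothetical step schedule in plain Python (the first loss alone before switch_epoch, every loss afterwards); StepScheduleSketch, switch_epoch and the two scalar loss terms are illustrative assumptions, and summing the selected losses is likewise assumed.

```python
from typing import Callable, List, Optional

# Hypothetical stand-in for a loss: any callable taking keyword inputs and returning a float.
LossFn = Callable[..., float]


class StepScheduleSketch:
    """Hypothetical scheduler: only the first loss before switch_epoch, all losses afterwards."""

    def __init__(self, *loss: LossFn, switch_epoch: int = 5):
        self.losses: List[LossFn] = list(loss)
        self.switch_epoch = switch_epoch

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        # Treat a missing epoch as "before the switch".
        if epoch is None or epoch < self.switch_epoch:
            return self.losses[:1]
        return self.losses

    def forward(self, x_net=None, x=None, y=None, physics=None, model=None, epoch=None, **kwargs) -> float:
        # The same inputs are passed to every selected loss; epoch drives the selection.
        selected = self.schedule(epoch)
        return sum(
            l(x_net=x_net, x=x, y=y, physics=physics, model=model, epoch=epoch, **kwargs)
            for l in selected
        )


def data_term(x_net, y, **kwargs):
    # Hypothetical measurement-consistency term on scalars (identity physics assumed).
    return (x_net - y) ** 2


def prior_term(x_net, **kwargs):
    # Hypothetical regularization term.
    return 0.5 * x_net ** 2


sched = StepScheduleSketch(data_term, prior_term, switch_epoch=5)
early = sched.forward(x_net=2.0, y=1.0, epoch=0)   # data_term only
late = sched.forward(x_net=2.0, y=1.0, epoch=10)   # data_term + prior_term
```

Early in training only the data term contributes (1.0); after the switch both terms are summed (1.0 + 2.0).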

schedule(epoch: int) → List[Loss]

Return the losses selected by the defined schedule, optionally depending on the current epoch.

Parameters:
  • epoch (int) – current epoch number

Returns:
  • list[Loss] – selected (sub)list of losses to be used at this pass.
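A concrete scheduler mainly has to override schedule. As a hypothetical example, the sketch below picks one wrapped loss uniformly at random on each call, drawing from the scheduler's own generator; random.Random stands in for torch.Generator and RandomScheduleSketch is an illustrative name, not a deepinv class. Seeding the generator makes the sequence of selections reproducible.

```python
import random
from typing import Callable, List, Optional

# Hypothetical stand-in for a loss: any callable taking keyword inputs and returning a float.
LossFn = Callable[..., float]


class RandomScheduleSketch:
    """Hypothetical scheduler that selects exactly one wrapped loss per call."""

    def __init__(self, *loss: LossFn, generator: Optional[random.Random] = None):
        self.losses: List[LossFn] = list(loss)
        self.rng = generator if generator is not None else random.Random()

    def schedule(self, epoch: Optional[int] = None) -> List[LossFn]:
        # Ignore the epoch; draw one index uniformly with the scheduler's own generator.
        idx = self.rng.randrange(len(self.losses))
        return [self.losses[idx]]


def loss_a(**kwargs):
    # Hypothetical constant loss.
    return 1.0


def loss_b(**kwargs):
    # Hypothetical constant loss.
    return 2.0


# Two schedulers seeded identically produce the same sequence of selections.
s1 = RandomScheduleSketch(loss_a, loss_b, generator=random.Random(0))
s2 = RandomScheduleSketch(loss_a, loss_b, generator=random.Random(0))
picks1 = [s1.schedule(e)[0].__name__ for e in range(10)]
picks2 = [s2.schedule(e)[0].__name__ for e in range(10)]
```

Keeping the generator on the scheduler, rather than using global random state, isolates the schedule's randomness from other random draws in the training loop.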