BaseLossScheduler
- class deepinv.loss.BaseLossScheduler(*loss: Loss, generator: Generator | None = None)
Bases: Loss
Base class for loss schedulers.
Wraps one or more losses. Each time forward is called, a subset of them is selected according to the schedule and applied to the inputs.
- Parameters:
*loss (Loss) – loss or multiple losses to be scheduled.
generator (Generator) – PyTorch random number generator (torch.Generator), defaults to None
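Because this is a base class, the selection rule presumably lives in subclasses. The stdlib-only sketch below illustrates the wrapping pattern under stated assumptions: the names SketchLossScheduler, RandomOneScheduler and the schedule hook are hypothetical, random.Random stands in for torch.Generator, losses are plain callables on floats, and the selected losses are summed. It is not deepinv's implementation.

```python
import random
from typing import Callable, List, Optional

# Stand-in for a loss: any callable mapping (x_net, x) to a float.
LossFn = Callable[[float, float], float]


class SketchLossScheduler:
    """Hypothetical stand-in for a loss scheduler (not deepinv code).

    Wraps one or more losses; ``schedule`` decides which of them are used
    on each call. This base version uses every wrapped loss.
    """

    def __init__(self, *loss: LossFn, generator: Optional[random.Random] = None):
        self.losses: List[LossFn] = list(loss)
        # random.Random stands in for torch.Generator in this sketch.
        self.rng = generator if generator is not None else random.Random()

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        # Base schedule (assumed): every wrapped loss is active.
        return self.losses

    def __call__(self, x_net: float, x: float, epoch: Optional[int] = None) -> float:
        # Apply the selected losses and sum them (summation is an assumption).
        return sum(loss_fn(x_net, x) for loss_fn in self.schedule(epoch))


class RandomOneScheduler(SketchLossScheduler):
    """Hypothetical subclass: picks one wrapped loss uniformly at random per call."""

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        return [self.rng.choice(self.losses)]


def l1(x_net: float, x: float) -> float:
    return abs(x_net - x)


def l2(x_net: float, x: float) -> float:
    return (x_net - x) ** 2


base = SketchLossScheduler(l1, l2)
print(base(3.0, 1.0))  # 6.0, i.e. 2.0 + 4.0

rand = RandomOneScheduler(l1, l2, generator=random.Random(0))
print(rand(3.0, 1.0))  # 2.0 or 4.0, fixed by the seeded generator
```

Passing a seeded generator makes the random selection reproducible across runs, which is the usual reason to expose a generator argument.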
- adapt_model(model: Module, **kwargs)
Adapt the model using all wrapped losses, not only those selected by the schedule.
Some loss functions need to modify the model's forward call before it is used, so each wrapped loss is given the chance to do so.
- Parameters:
model (Module) – reconstruction model
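A minimal sketch of adapting through every wrapped loss in turn, assuming each loss exposes an adapt_model method that returns the (possibly wrapped) model. PlainLoss, ScaledOutputLoss and SchedulerSketch are hypothetical names, and models are plain float functions rather than torch modules.

```python
from typing import Callable, List

# Stand-in for a reconstruction model: maps a measurement to an estimate.
Model = Callable[[float], float]


class PlainLoss:
    """Hypothetical loss that needs no model adaptation."""

    def adapt_model(self, model: Model, **kwargs) -> Model:
        return model


class ScaledOutputLoss:
    """Hypothetical loss that must wrap the model's forward call,
    here by scaling the model output by a fixed factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def adapt_model(self, model: Model, **kwargs) -> Model:
        def adapted(y: float) -> float:
            return self.factor * model(y)

        return adapted


class SchedulerSketch:
    """Hypothetical scheduler stand-in: adaptation ignores the schedule
    and applies every wrapped loss's adapt_model, in order."""

    def __init__(self, *loss):
        self.losses: List = list(loss)

    def adapt_model(self, model: Model, **kwargs) -> Model:
        for loss_fn in self.losses:
            model = loss_fn.adapt_model(model, **kwargs)
        return model


def base_model(y: float) -> float:
    return y + 1.0


scheduler = SchedulerSketch(PlainLoss(), ScaledOutputLoss(2.0), ScaledOutputLoss(3.0))
adapted = scheduler.adapt_model(base_model)
print(adapted(1.0))  # 12.0 = 3.0 * (2.0 * (1.0 + 1.0))
```

Chaining the adaptations means a loss that is rarely selected by the schedule still sees a model adapted for it whenever it is used.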
- forward(x_net: Tensor | None = None, x: Tensor | None = None, y: Tensor | None = None, physics: Physics | None = None, model: Module | None = None, epoch: int | None = None, **kwargs)
Loss forward pass.
When called, selects the losses to use at this pass according to the schedule and applies them to the inputs.
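One possible epoch-based schedule, sketched in plain Python: at epoch e only loss number e mod n is used, and when no epoch is given all losses are used. The class name EpochInterleavedSketch, the schedule hook, the two toy losses, the fallback for epoch=None, and summing the selected losses are all assumptions for illustration, not deepinv behavior. The keyword arguments mirror the forward signature above.

```python
from typing import Any, Callable, List, Optional

# Stand-in loss signature: accepts the same keyword arguments as forward.
LossFn = Callable[..., float]


def supervised(*, x_net=None, x=None, **kwargs) -> float:
    """Hypothetical supervised loss: squared error to the ground truth."""
    return (x_net - x) ** 2


def consistency(*, x_net=None, y=None, physics=None, **kwargs) -> float:
    """Hypothetical measurement-consistency loss: squared error in measurement space."""
    return (physics(x_net) - y) ** 2


class EpochInterleavedSketch:
    """Hypothetical schedule (not deepinv code): at epoch e, use only loss e mod n."""

    def __init__(self, *loss: LossFn):
        self.losses: List[LossFn] = list(loss)

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        if epoch is None:
            return self.losses  # no epoch given: fall back to all losses
        return [self.losses[epoch % len(self.losses)]]

    def forward(self, x_net=None, x=None, y=None, physics=None, model=None,
                epoch: Optional[int] = None, **kwargs: Any) -> float:
        total = 0.0
        # Select the losses for this pass, then apply each to the same inputs.
        for loss_fn in self.schedule(epoch):
            total += loss_fn(x_net=x_net, x=x, y=y, physics=physics, model=model,
                             epoch=epoch, **kwargs)
        return total


def double(x: float) -> float:
    """Hypothetical forward operator A(x) = 2x."""
    return 2.0 * x


sched = EpochInterleavedSketch(supervised, consistency)
for epoch in (0, 1, None):
    print(epoch, sched.forward(x_net=3.0, x=1.0, y=5.0, physics=double, epoch=epoch))
# 0 4.0     (supervised only: (3 - 1)^2)
# 1 1.0     (consistency only: (2*3 - 5)^2)
# None 5.0  (both losses)
```

Forwarding every keyword argument to each selected loss lets losses with different input needs (ground truth versus measurements) share one call site.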