BaseLossScheduler
- class deepinv.loss.BaseLossScheduler(*loss, generator=None)
Bases: Loss
Base class for loss schedulers.
Wraps a list of losses; each time forward is called, a subset of them is selected according to a defined schedule and applied to the inputs.
- Parameters:
*loss (Loss) – loss or multiple losses to be scheduled.
generator (torch.Generator) – torch random number generator; defaults to None.
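The wrapping pattern described above can be sketched in plain Python. This is not deepinv's implementation. It makes several assumptions: losses are stand-in callables returning floats, random.Random replaces torch.Generator, and the names ToyLossScheduler, RandomSubsetScheduler and schedule are all hypothetical. It shows how a seeded generator makes a random schedule reproducible.

```python
import random
from typing import Callable, List, Optional

# Hypothetical stand-in for a loss: any callable returning a float.
LossFn = Callable[..., float]


class ToyLossScheduler:
    """Hypothetical base scheduler: wraps losses and picks some per call."""

    def __init__(self, *loss: LossFn, generator: Optional[random.Random] = None):
        self.losses: List[LossFn] = list(loss)
        # Fall back to a fresh, unseeded RNG when no generator is supplied.
        self.rng = generator if generator is not None else random.Random()

    def schedule(self, epoch: Optional[int] = None) -> List[LossFn]:
        # Default (hypothetical) schedule: use every wrapped loss.
        return self.losses


class RandomSubsetScheduler(ToyLossScheduler):
    """Hypothetical schedule: pick exactly one loss uniformly at random."""

    def schedule(self, epoch: Optional[int] = None) -> List[LossFn]:
        return [self.rng.choice(self.losses)]


# Two toy losses and a seeded scheduler.
sup = lambda **kw: 1.0
unsup = lambda **kw: 2.0
sched = RandomSubsetScheduler(sup, unsup, generator=random.Random(0))
picked = sched.schedule()  # a single-element list holding sup or unsup
```

Passing the same seeded generator reproduces the same sequence of selections, which is the usual reason for exposing a generator argument on a random schedule.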
- adapt_model(model, **kwargs)
Adapt the model using all wrapped losses, not only those selected by the schedule.
Some loss functions require the model's forward call to be adapted before the forward pass is run.
- Parameters:
model (torch.nn.Module) – reconstruction model
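A minimal sketch of "adapt using all wrapped losses", assuming each loss exposes an adapt_model(model, **kwargs) method that returns the (possibly wrapped) model, and that the model is threaded through every loss in order. IdentityLoss, ScaleOutputLoss and adapt_with_all are hypothetical plain-Python stand-ins, not deepinv classes, and the "model" is a toy callable rather than a torch.nn.Module.

```python
from typing import Any, Callable, List


class IdentityLoss:
    """Hypothetical loss that does not need to adapt the model."""

    def adapt_model(self, model: Callable, **kwargs: Any) -> Callable:
        return model


class ScaleOutputLoss:
    """Hypothetical loss that wraps the model's forward call."""

    def __init__(self, factor: float):
        self.factor = factor

    def adapt_model(self, model: Callable, **kwargs: Any) -> Callable:
        # Return a new callable that rescales the wrapped model's output.
        return lambda y: self.factor * model(y)


def adapt_with_all(losses: List[Any], model: Callable, **kwargs: Any) -> Callable:
    # Apply every wrapped loss in turn, independent of any schedule.
    for loss in losses:
        model = loss.adapt_model(model, **kwargs)
    return model


base_model = lambda y: y + 1.0
adapted = adapt_with_all(
    [IdentityLoss(), ScaleOutputLoss(2.0), ScaleOutputLoss(3.0)], base_model
)
print(adapted(1.0))  # 3 * (2 * (1 + 1)) = 12.0
```

Adapting with every loss up front, rather than only the scheduled ones, means the model keeps a single fixed structure no matter which losses a given pass selects.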
- forward(x_net=None, x=None, y=None, physics=None, model=None, epoch=None, **kwargs)
Loss forward pass.
When called, selects the losses to use at this pass according to the defined schedule and applies them to the inputs.
- Parameters:
x_net (torch.Tensor) – model output
x (torch.Tensor) – ground truth
y (torch.Tensor) – measurement
physics (Physics) – measurement operator
model (torch.nn.Module) – reconstruction model
epoch (int) – current epoch
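The forward logic can be sketched as follows: ask the schedule which losses to use at this epoch, evaluate each on the same inputs, and combine the results. Two assumptions are labeled here. First, the selected losses are summed, since the description above does not state the reduction. Second, an alternating epoch-based rule stands in for a concrete schedule. Plain floats replace torch.Tensor, and AlternatingScheduler, mse and meas are hypothetical names.

```python
from typing import Any, Callable, List, Optional

LossFn = Callable[..., float]


class AlternatingScheduler:
    """Hypothetical schedule: use losses[epoch % len(losses)] at each epoch."""

    def __init__(self, *loss: LossFn):
        self.losses: List[LossFn] = list(loss)

    def schedule(self, epoch: Optional[int]) -> List[LossFn]:
        if epoch is None:
            # Without an epoch, fall back to all losses (an assumption).
            return self.losses
        return [self.losses[epoch % len(self.losses)]]

    def forward(self, x_net=None, x=None, y=None, physics=None, model=None,
                epoch=None, **kwargs: Any) -> float:
        # Evaluate each selected loss on the same inputs and sum (assumed).
        return sum(
            loss(x_net=x_net, x=x, y=y, physics=physics, model=model, **kwargs)
            for loss in self.schedule(epoch)
        )


def mse(x_net=None, x=None, **kwargs):
    # Supervised stand-in: squared error to the ground truth.
    return (x_net - x) ** 2


def meas(x_net=None, y=None, physics=None, **kwargs):
    # Measurement-consistency stand-in: squared error in measurement space.
    return (physics(x_net) - y) ** 2


A = lambda v: 2.0 * v  # toy linear "physics" operator
sched = AlternatingScheduler(mse, meas)
for epoch in range(3):
    print(epoch, sched.forward(x_net=1.5, x=1.0, y=2.0, physics=A, epoch=epoch))
# 0 0.25
# 1 1.0
# 2 0.25
```

Each stand-in loss accepts **kwargs so that it can ignore the inputs it does not use, which lets the scheduler pass the full argument set to every selected loss.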