BaseLossScheduler

class deepinv.loss.BaseLossScheduler(*loss, generator=None)

Bases: Loss

Base class for loss schedulers.

Wraps one or more losses. Each time forward is called, a subset of them is selected according to the schedule implemented by schedule, and only the selected losses are applied.

Parameters:
  • *loss (Loss) – one or more losses to be scheduled.

  • generator (torch.Generator) – torch random number generator, defaults to None.
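The role of the generator argument can be illustrated with a minimal pure-Python sketch of a randomized schedule. It is not deepinv's implementation: ToyLoss and RandomSubsetScheduler are hypothetical names, and random.Random stands in for the torch.Generator that the real class accepts. The point is that two schedulers seeded identically make identical selections, which makes randomized schedules reproducible.

```python
import random


class ToyLoss:
    """Hypothetical stand-in for a deepinv Loss: returns a fixed value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def forward(self, **kwargs):
        return self.value


class RandomSubsetScheduler:
    """Hypothetical scheduler that picks one wrapped loss at random per call.

    random.Random plays the role of the torch.Generator passed as `generator`.
    """

    def __init__(self, *loss, generator=None):
        self.losses = list(loss)
        self.rng = generator if generator is not None else random.Random()

    def schedule(self, epoch=None):
        # Select a single loss uniformly at random; return it as a list.
        return [self.rng.choice(self.losses)]


sup, mc = ToyLoss("sup", 1.0), ToyLoss("mc", 2.0)
a = RandomSubsetScheduler(sup, mc, generator=random.Random(0))
b = RandomSubsetScheduler(sup, mc, generator=random.Random(0))
picks_a = [a.schedule()[0].name for _ in range(5)]
picks_b = [b.schedule()[0].name for _ in range(5)]
print(picks_a == picks_b)  # True: same seed gives the same sequence of selections
```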

adapt_model(model, **kwargs)

Adapt model using all wrapped losses.

Some losses require the model's forward call to be adapted before it is used. This method applies the adaptation of every wrapped loss, not only of those selected by the schedule.

Parameters:

model (torch.nn.Module) – reconstruction model
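A minimal sketch of how adapting a model "using all wrapped losses" can work, assuming each loss exposes an adapt_model(model, **kwargs) method that returns the (possibly wrapped) model. ScaleOutput, ScalingLoss, PlainLoss and adapt_model_all are hypothetical illustrations, not deepinv classes; the model is passed through every loss in order, so adaptations compose.

```python
class ScaleOutput:
    """Hypothetical model wrapper that multiplies the wrapped model's output."""

    def __init__(self, model, factor):
        self.model = model
        self.factor = factor

    def __call__(self, y):
        return self.factor * self.model(y)


class ScalingLoss:
    """Hypothetical loss whose adapt_model wraps the model."""

    def __init__(self, factor):
        self.factor = factor

    def adapt_model(self, model, **kwargs):
        return ScaleOutput(model, self.factor)


class PlainLoss:
    """Hypothetical loss that needs no adaptation and returns the model unchanged."""

    def adapt_model(self, model, **kwargs):
        return model


def adapt_model_all(losses, model, **kwargs):
    # Pass the model through every wrapped loss in order; each may wrap it.
    for loss in losses:
        model = loss.adapt_model(model, **kwargs)
    return model


def identity(y):
    return y


adapted = adapt_model_all([ScalingLoss(2.0), PlainLoss(), ScalingLoss(3.0)], identity)
print(adapted(1.0))  # 6.0: both scaling wrappers were applied
```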

forward(x_net=None, x=None, y=None, physics=None, model=None, epoch=None, **kwargs)

Loss forward pass.

When called, selects the losses to use at this pass according to the schedule and applies them to the inputs.
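The select-then-apply pattern can be sketched in plain Python. This assumes the selected losses are combined by summing their values, which is an assumption of the sketch rather than something stated above; ConstLoss and EvenOddScheduler are hypothetical names, and the even/odd epoch rule is only an illustrative schedule.

```python
class ConstLoss:
    """Hypothetical loss returning a constant value regardless of inputs."""

    def __init__(self, value):
        self.value = value

    def forward(self, x_net=None, x=None, y=None, physics=None, model=None, epoch=None, **kwargs):
        return self.value


class EvenOddScheduler:
    """Hypothetical scheduler: first loss on even epochs, all losses on odd epochs."""

    def __init__(self, *loss):
        self.losses = list(loss)

    def schedule(self, epoch):
        return self.losses[:1] if epoch % 2 == 0 else self.losses

    def forward(self, x_net=None, x=None, y=None, physics=None, model=None, epoch=None, **kwargs):
        # Select the losses for this pass, then sum their values on the same inputs
        # (summation is an assumption of this sketch).
        total = 0.0
        for loss in self.schedule(epoch):
            total += loss.forward(x_net=x_net, x=x, y=y, physics=physics, model=model, epoch=epoch, **kwargs)
        return total


sched = EvenOddScheduler(ConstLoss(1.0), ConstLoss(0.5))
print(sched.forward(epoch=0), sched.forward(epoch=1))  # 1.0 1.5
```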

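Parameters:
  • x_net (torch.Tensor) – output of the reconstruction model.

  • x (torch.Tensor) – ground-truth image.

  • y (torch.Tensor) – measurements.

  • physics (Physics) – forward operator.

  • model (torch.nn.Module) – reconstruction model.

  • epoch (int) – current epoch number, passed to schedule.

schedule(epoch)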

Return the losses selected by the schedule, optionally depending on the current epoch.

Parameters:

epoch (int) – current epoch number

Returns:

Selected sublist of the wrapped losses to use at this pass.

Return type:

List[Loss]
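As an illustration of an epoch-based schedule, here is a minimal sketch that cycles through the wrapped losses, using one per epoch. CyclicEpochScheduler is a hypothetical name and strings stand in for Loss objects; the method returns a one-element list so that it matches the List[Loss] return type documented above.

```python
class CyclicEpochScheduler:
    """Hypothetical scheduler: one wrapped loss per epoch, cycling in order."""

    def __init__(self, *loss):
        self.losses = list(loss)

    def schedule(self, epoch):
        # Pick the loss whose index matches the epoch modulo the number of losses,
        # wrapped in a list to match the List[Loss] return type.
        return [self.losses[epoch % len(self.losses)]]


sched = CyclicEpochScheduler("loss_a", "loss_b", "loss_c")
print([sched.schedule(e) for e in range(4)])
# [['loss_a'], ['loss_b'], ['loss_c'], ['loss_a']]
```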