train
- deepinv.train(model, physics, optimizer, train_dataloader, epochs=100, losses=SupLoss(), eval_dataloader=None, *args, **kwargs)
Alias function for training a model using the deepinv.Trainer class. This function creates a Trainer instance and returns the trained model.
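The delegation described above can be sketched in plain Python. This is a minimal, hypothetical illustration, not deepinv's source. The Trainer class below is a stand-in that only stores its arguments. None replaces the real SupLoss() default for losses. The DeprecationWarning is an assumption about how a deprecation might be surfaced. The sketch runs without torch or deepinv installed.

```python
import warnings


class Trainer:
    """Hypothetical stand-in for deepinv.Trainer (not the real class).

    It only stores its configuration so the alias pattern can run without
    torch or deepinv installed.
    """

    def __init__(self, model, physics, optimizer, train_dataloader,
                 epochs=100, losses=None, eval_dataloader=None,
                 *args, **kwargs):
        self.model = model
        self.physics = physics
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.epochs = epochs
        self.losses = losses          # None is a placeholder for SupLoss()
        self.eval_dataloader = eval_dataloader
        self.extra_args = args        # extra positional options
        self.options = kwargs         # extra keyword options

    def train(self):
        # A real trainer would iterate over epochs and batches here;
        # this stand-in simply returns the model it was given.
        return self.model


def train(model, physics, optimizer, train_dataloader, epochs=100,
          losses=None, eval_dataloader=None, *args, **kwargs):
    """Alias: build a Trainer, run it, and return the trained model."""
    # Assumption: the deprecation is signalled with a DeprecationWarning.
    warnings.warn("train() is deprecated; use Trainer instead.",
                  DeprecationWarning, stacklevel=2)
    # Forward everything, including extra positional and keyword options.
    trainer = Trainer(model, physics, optimizer, train_dataloader, epochs,
                      losses, eval_dataloader, *args, **kwargs)
    return trainer.train()
```

For example, train("net", "physics", "optimizer", ["batch"], epochs=3, verbose=False) returns "net", and verbose=False reaches the stand-in Trainer as a keyword option. With the real library, the recommended replacement has the same shape. You construct deepinv.Trainer with these arguments (model, physics=..., optimizer=..., train_dataloader=..., epochs=..., losses=..., eval_dataloader=...) and then call its train() method, which returns the trained model.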
Warning
This function is deprecated and will be removed in future versions. Please use deepinv.Trainer instead.
- Parameters:
model (torch.nn.Module) – Reconstruction network, which can be PnP, unrolled, artifact removal or any other custom reconstruction network.
physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used by the reconstruction network.
optimizer (torch.optim.Optimizer) – Torch optimizer for training the network.
train_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Train data loader(s) should provide a signal x or a tuple of (x, y) signal/measurement pairs.
epochs (int) – Number of training epochs. Default is 100.
losses (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Loss or list of losses used for training the model. See the library’s training losses. By default, it uses the supervised mean squared error (SupLoss).
eval_dataloader (None, torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Evaluation data loader(s) should provide a signal x or a tuple of (x, y) signal/measurement pairs.
args – Other positional arguments to pass to the Trainer constructor. See deepinv.Trainer.
kwargs – Keyword arguments to pass to the Trainer constructor. See deepinv.Trainer.
- Returns:
Trained model.