train
- class deepinv.train(model: torch.nn.modules.module.Module, physics: deepinv.physics.forward.Physics, optimizer: torch.optim.optimizer.Optimizer, train_dataloader: torch.utils.data.dataloader.DataLoader, epochs: int = 100, losses: deepinv.loss.loss.Loss | typing.List[deepinv.loss.loss.Loss] = SupLoss(metric=MSELoss()), eval_dataloader: torch.utils.data.dataloader.DataLoader | None = None, *args, **kwargs)
Alias function for training a model using the deepinv.Trainer class. This function creates a Trainer instance and returns the trained model.
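The alias pattern described above (build a trainer, run it, return the model) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not deepinv's source: ToyTrainer and toy_train are hypothetical stand-ins for deepinv.Trainer and deepinv.train, the "training" step only records metadata, and emitting a DeprecationWarning is an assumed way of surfacing the deprecation notice below.

```python
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for deepinv.Trainer (NOT the real class)."""
    model: Any
    physics: Any
    optimizer: Any
    train_dataloader: Any
    epochs: int = 100
    losses: List[Callable] = field(default_factory=list)
    eval_dataloader: Any = None
    verbose: bool = False  # hypothetical extra option, reachable via *args/**kwargs

    def train(self):
        # A real trainer would loop over epochs and batches here;
        # this stand-in only records what it was configured with.
        self.model["epochs_run"] = self.epochs
        self.model["verbose"] = self.verbose
        return self.model


def toy_train(model, physics, optimizer, train_dataloader, epochs=100,
              losses=None, eval_dataloader=None, *args, **kwargs):
    """Hypothetical alias mirroring the documented deepinv.train signature."""
    warnings.warn(
        "toy_train is deprecated; use ToyTrainer instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Build the trainer, forwarding any extra positional/keyword arguments.
    trainer = ToyTrainer(model, physics, optimizer, train_dataloader, epochs,
                         losses if losses is not None else [],
                         eval_dataloader, *args, **kwargs)
    # Run training and hand back the (trained) model.
    return trainer.train()
```

In this sketch the wrapper adds nothing beyond constructing the trainer and calling its train method, so any option missing from its explicit signature travels through **kwargs; migrating away from the alias amounts to constructing the trainer directly and calling train() on it.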
Warning
This function is deprecated and will be removed in a future version. Please use deepinv.Trainer instead.
- Parameters:
model (torch.nn.Module) – Reconstruction network, which can be PnP, unrolled, artifact removal or any other custom reconstruction network.
physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used by the reconstruction network.
epochs (int) – Number of training epochs. Default is 100.
optimizer (torch.optim.Optimizer) – Torch optimizer for training the network.
train_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Training data loader(s), which should provide a signal x or a tuple of (x, y) signal/measurement pairs.
losses (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Loss or list of losses used for training the model. See the library's training losses. By default, the supervised mean squared error loss is used.
eval_dataloader (None, torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Evaluation data loader(s), which should provide a signal x or a tuple of (x, y) signal/measurement pairs. Default is None.
args – Other positional arguments passed to the Trainer constructor. See deepinv.Trainer().
kwargs – Keyword arguments passed to the Trainer constructor. See deepinv.Trainer().
- Returns:
Trained model.