Trainer#

class deepinv.Trainer(model, physics, optimizer, train_dataloader, ...)[source]#

Bases: object

Trainer class for training a reconstruction network.

See also

See the User Guide for more details and for how to adapt the trainer to your needs.

Training can be done by calling the deepinv.Trainer.train() method, whereas testing can be done by calling the deepinv.Trainer.test() method.

Training details are saved every ckp_interval epochs in the following format

save_path/yyyy-mm-dd_hh-mm-ss/ckp_{epoch}.pth.tar

where the .pth.tar file contains a dictionary with the keys: epoch (the current epoch), state_dict (the state dictionary of the model), loss (the loss history), optimizer (the state dictionary of the optimizer), and eval_metrics (the evaluation metrics history).

Assuming that x is the ground-truth reference, y is the measurement, and params is a dict of physics parameters, the dataloaders should return one of the following options:

  1. (x, y) or (x, y, params), which requires online_measurements=False (default); otherwise y will be ignored and new measurements will be generated online.

  2. (x) or (x, params), which requires online_measurements=True for generating measurements in an online manner (optionally with parameters) as y=physics(x) or y=physics(x, **params). Otherwise, first generate a dataset of (x,y) with deepinv.datasets.generate_dataset and then use option 1 above.

  3. If you have a dataset of measurements only, (y) or (y, params), you should modify it such that it returns (torch.nan, y) or (torch.nan, y, params). Set online_measurements=False.
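
For example, a minimal sketch of option 3, wrapping a measurement-only dataset (the class and attribute names are illustrative, not part of the library):

    import torch
    from torch.utils.data import Dataset

    class MeasurementOnlyDataset(Dataset):
        # Illustrative wrapper: adapts a dataset containing only measurements y
        # to the (torch.nan, y) convention expected by the trainer (option 3).
        def __init__(self, measurements):
            self.measurements = measurements  # e.g. a list of y tensors

        def __len__(self):
            return len(self.measurements)

        def __getitem__(self, idx):
            return torch.nan, self.measurements[idx]  # no ground truth available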

Note

The losses and evaluation metrics can be chosen from our training losses or our metrics.

Custom losses can be used, as long as they take as input (x, x_net, y, physics, model) and return a tensor of length batch_size (i.e. reduction=None in the underlying metric, as we perform averaging to deal with uneven batch sizes), where x is the ground truth, x_net is the network reconstruction (the output of the model given y and physics), y is the measurement vector, physics is the forward operator and model is the reconstruction network. Note that not all inputs need to be used by the loss, e.g., self-supervised losses will not make use of x.

Custom metrics can also be used in the exact same way as custom losses.
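
As an illustration, a minimal sketch of a custom loss following this interface (the class name and the choice of penalty are hypothetical; arguments are received by keyword, so a custom metric follows the same pattern):

    import deepinv as dinv

    class MeasurementConsistency(dinv.loss.Loss):
        # Illustrative custom loss: per-sample squared residual in measurement
        # space, ||A(x_net) - y||^2. Returns one value per batch element
        # (no reduction), as the trainer requires. The ground truth x is
        # unused, so this behaves like a self-supervised loss.
        def forward(self, x, x_net, y, physics, model, **kwargs):
            residual = physics.A(x_net) - y
            return residual.flatten(1).pow(2).mean(dim=1)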

Note

The training code can synchronize with Weights & Biases for logging and visualization by setting wandb_vis=True. The user can also customize the wandb setup by providing a dictionary with the setup for wandb.
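
For example, a hypothetical setup dictionary (the keys are standard wandb.init arguments; the values are illustrative):

    # Illustrative wandb configuration passed to the trainer; the keys are
    # forwarded to wandb.init (project, run name, and a config of hyperparameters).
    wandb_setup = {
        "project": "deepinv-experiments",
        "name": "unet-inpainting-run1",
        "config": {"lr": 1e-4, "epochs": 100},
    }
    # trainer = deepinv.Trainer(..., wandb_vis=True, wandb_setup=wandb_setup)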

Parameters are described below, grouped into Basics, Optimization, Evaluation, Physics Generators, Model Saving, Comparison with Pseudoinverse Baseline, Plotting, Verbose and Weights & Biases.

Basics:

Parameters:
  • model (torch.nn.Module) – Reconstruction network to train.

  • physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used to model or generate the measurements.

  • train_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Train data loader(s), see options 1 to 3 above for how we expect data to be provided.

  • online_measurements (bool) – If True, measurements are generated online as y=physics(x) (option 2 above); if False (default), the dataloader must provide the measurements (options 1 and 3 above).


Optimization:

Parameters:
  • optimizer (None, torch.optim.Optimizer) – Torch optimizer for training the network. Default is None.

  • epochs (int) – Number of training epochs. Default is 100. The trainer performs min(epochs * n_batches, max_batch_steps) gradient steps in total.

  • max_batch_steps (int) – Maximum total number of gradient steps (batch iterations). Default is 1e10. The trainer performs min(epochs * n_batches, max_batch_steps) gradient steps in total.

  • scheduler (None, torch.optim.lr_scheduler.LRScheduler) – Torch scheduler for changing the learning rate across iterations. Default is None.

  • early_stop (bool) – If True, the training stops when the evaluation metric is not improving. Default is False. The user can modify the early stopping strategy by overriding the deepinv.Trainer.stop_criterion() method.

  • losses (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Loss or list of losses used for training the model. Optionally wrap losses using a loss scheduler for more advanced training. See the library's training losses. Where relevant, the underlying metric should have reduction=None as we perform the averaging using deepinv.utils.AverageMeter to deal with uneven batch sizes. Default is supervised loss.

  • grad_clip (float) – Gradient clipping value for the optimizer. If None, no gradient clipping is performed. Default is None.

  • optimizer_step_multi_dataset (bool) – If True, the optimizer step is performed once on all datasets. If False, the optimizer step is performed on each dataset separately.
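
Putting these options together, a minimal sketch with toy stand-ins for the model, physics and data (TinyModel and RandomImages are illustrative, not part of the library):

    import torch
    import deepinv as dinv

    # Toy placeholders: a denoising problem on random images.
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))

    class TinyModel(torch.nn.Module):
        # Toy reconstruction network taking (y, physics), as the trainer expects.
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 3, padding=1)

        def forward(self, y, physics):
            return self.conv(y)

    class RandomImages(torch.utils.data.Dataset):
        # Toy ground-truth dataset following option 2 above (online measurements).
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return torch.randn(1, 32, 32)

    model = TinyModel()
    train_dataloader = torch.utils.data.DataLoader(RandomImages(), batch_size=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    trainer = dinv.Trainer(
        model,
        physics=physics,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        epochs=2,
        online_measurements=True,
        grad_clip=1.0,  # clip the gradient norm at 1.0
    )
    model = trainer.train()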


Evaluation:

Parameters:
  • eval_dataloader (None, torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Evaluation data loader(s), see options 1 to 3 above for how we expect data to be provided.

  • metrics (Metric, list[Metric]) – Metric or list of metrics used for evaluating the model. They should have reduction=None as we perform the averaging using deepinv.utils.AverageMeter to deal with uneven batch sizes. See the library's evaluation metrics. Default is PSNR.

  • eval_interval (int) – Number of epochs (or train iters, if log_train_batch=True) between each evaluation of the model on the evaluation set. Default is 1.

  • log_train_batch (bool) – If True, log train batch and eval-set metrics and losses for each train batch during training. This is useful for visualising train progress inside an epoch, not just over epochs. If False (default), log the average over the dataset per epoch (standard training).
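
For example, assuming the toy setup sketched above, evaluation could be configured as follows (eval_dataloader is a placeholder validation loader; PSNR and SSIM are the library's metrics):

    import deepinv as dinv

    # Evaluation sketch: PSNR (also used to select the best model) and SSIM,
    # computed on the validation set every 5 epochs.
    trainer = dinv.Trainer(
        model,                            # placeholder network, as above
        physics=physics,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,  # placeholder validation loader
        metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
        eval_interval=5,
        online_measurements=True,
    )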

Tip

If a validation dataloader eval_dataloader is provided, the trainer will also save the best model according to the first metric in the list, using the following format: save_path/yyyy-mm-dd_hh-mm-ss/ckp_best.pth.tar. The user can modify the strategy for saving the best model by overriding the deepinv.Trainer.save_best_model() method. The best model can be also loaded using the deepinv.Trainer.load_best_model() method.
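
For example:

    trainer.train()
    model = trainer.load_best_model()  # loads ckp_best.pth.tar from save_path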


Physics Generators:

Parameters:
  • physics_generator (None, deepinv.physics.generator.PhysicsGenerator) – Optional physics generator for generating the physics operators. If not None, the physics operators are randomly sampled at each iteration using the generator. Should be used in conjunction with online_measurements=True, no effect when online_measurements=False. Also see loop_random_online_physics. Default is None.

  • loop_random_online_physics (bool) – If True, resets the physics generator and noise model back to their initial state at the beginning of each epoch, so that the same measurements are generated each epoch. Requires shuffle=False in dataloaders. If False, generates new physics every epoch. Used in conjunction with online_measurements=True and physics_generator or a noise model in physics; no effect when online_measurements=False. Default is False.

Warning

If the physics changes at each iteration for online measurements (e.g. if physics_generator is used to generate random physics operators or noise model is used), the generated measurements will randomly vary each epoch. If this is not desired (i.e. you want the same online measurements each epoch), set loop_random_online_physics=True. This resets the physics generator and noise model’s random generators every epoch.

Caveat: this requires shuffle=False in your dataloaders.

An alternative, safer solution is to generate and save params offline using deepinv.datasets.generate_dataset(). The params dict will then be automatically updated every time data is loaded.
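
As a sketch of online measurements with randomly sampled physics (the generator choice and tensor sizes are illustrative; model, optimizer and train_dataloader are placeholders as above):

    import deepinv as dinv

    # Illustrative sketch: random inpainting masks re-sampled at each iteration,
    # reset at every epoch so the same measurements are generated each epoch.
    physics = dinv.physics.Inpainting(tensor_size=(1, 32, 32), mask=0.5)
    generator = dinv.physics.generator.BernoulliSplittingMaskGenerator(
        tensor_size=(1, 32, 32), split_ratio=0.5
    )

    trainer = dinv.Trainer(
        model,                              # placeholder network, as above
        physics=physics,
        physics_generator=generator,
        online_measurements=True,
        loop_random_online_physics=True,    # same measurements every epoch
        optimizer=optimizer,
        train_dataloader=train_dataloader,  # must be built with shuffle=False
    )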


Model Saving:

Parameters:
  • save_path (str) – Directory in which to save the trained model. Default is "." (current folder).

  • ckp_interval (int) – The model is saved every ckp_interval epochs. Default is 1.

  • ckpt_pretrained (str) – Path of the pretrained checkpoint. If None (default), no pretrained checkpoint is loaded.
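
For example (the checkpoint path is illustrative):

    import deepinv as dinv

    # Illustrative paths: save every 10 epochs and resume from an earlier run.
    trainer = dinv.Trainer(
        model,                     # placeholder network, as above
        physics=physics,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        save_path="checkpoints",
        ckp_interval=10,
        ckpt_pretrained="checkpoints/2024-01-01_00-00-00/ckp_10.pth.tar",
    )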


Comparison with Pseudoinverse Baseline:

Parameters:
  • compare_no_learning (bool) – If True, the no learning method is compared to the network reconstruction. Default is False.

  • no_learning_method (str) – Reconstruction method used for the no learning comparison. Options are 'A_dagger', 'A_adjoint', 'prox_l2', or 'y'. Default is 'A_dagger'. The user can also provide a custom method by overriding the no_learning_inference method.
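
For example, a sketch of a custom no-learning baseline (the subclass name is illustrative):

    import deepinv as dinv

    class MyTrainer(dinv.Trainer):
        # Illustrative override: compare against the adjoint reconstruction
        # (assumes a linear physics exposing A_adjoint).
        def no_learning_inference(self, y, physics):
            return physics.A_adjoint(y)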


Plotting:

Parameters:
  • plot_images (bool) – Plots reconstructions every ckp_interval epochs. Default is False.

  • plot_measurements (bool) – Plot the measurements y. Default is True.

  • plot_convergence_metrics (bool) – Plot convergence metrics for the model. Default is False.

  • rescale_mode (str) – Rescale mode for plotting images. Default is 'clip'.


Verbose:

Parameters:
  • verbose (bool) – Output training progress information in the console. Default is True.

  • verbose_individual_losses (bool) – If True, the values of the individual losses are printed during training. Otherwise, only the total loss is printed. Default is True.

  • show_progress_bar (bool) – Show a progress bar during training. Default is True.

  • check_grad (bool) – Compute and print the gradient norm at each iteration. Default is False.

  • display_losses_eval (bool) – If True, the losses are displayed during evaluation. Default is False.


Weights & Biases:

Parameters:
  • wandb_vis (bool) – Logs data onto Weights & Biases, see https://wandb.ai/ for more details. Default is False.

  • wandb_setup (dict) – Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more details. Default is {}.

  • plot_interval (int) – Frequency of plotting images to wandb during train evaluation (at the end of each epoch). If 1, plots at each epoch. Default is 1.

  • freq_plot (int) – Deprecated. Use plot_interval instead.

check_clip_grad()[source]#

Check the gradient norm and perform gradient clipping if necessary.

compute_loss(physics, x, y, train=True, epoch=None, step=False)[source]#

Compute the loss and perform the backward pass.

It evaluates the reconstruction network, computes the losses, and performs the backward pass.

Parameters:
  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Ground truth.

  • y (torch.Tensor) – Measurement.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • epoch (int) – current epoch.

  • step (bool) – Whether to perform an optimization step when computing the loss.

Returns:

(tuple) The network reconstruction x_net (for plotting and computing metrics) and the logs (for printing the training progress).

compute_metrics(x, x_net, y, physics, logs, train=True, epoch=None)[source]#

Compute the metrics.

It computes the metrics over the batch.

Parameters:
  • x (torch.Tensor) – Ground truth.

  • x_net (torch.Tensor) – Network reconstruction.

  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • logs (dict) – Dictionary containing the logs for printing the training progress.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • epoch (int) – current epoch.

Returns:

The logs with the metrics.

get_samples(iterators, g)[source]#

Get the samples.

This function returns a dictionary containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data.

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

the dictionary returned by the get_samples_online or get_samples_offline function.

get_samples_offline(iterators, g)[source]#

Get the samples for the offline measurements.

In this setting, samples have been generated offline and are loaded from the dataloader. This function returns a dictionary containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data (you can override this function to add custom data).

If the dataloader returns 3-tuples, this is assumed to be (x, y, params) where params is a dict of physics generator params. These params are then used to update the physics.

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

a dictionary containing at least: the ground truth, the measurement, and the current physics operator.

get_samples_online(iterators, g)[source]#

Get the samples for the online measurements.

In this setting, a new sample is generated at each iteration by calling the physics operator. This function returns a dictionary containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data.

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

a dictionary containing at least: the ground truth, the measurement, and the current physics operator.

load_best_model()[source]#

Load the best model.

It loads the model from the checkpoint saved during training.

Returns:

The model.

load_model(ckpt_pretrained=None)[source]#

Load model from checkpoint.

Parameters:

ckpt_pretrained (str) – Checkpoint filename. If None, the checkpoint passed at class initialization is used; if not None, it overrides the checkpoint passed to the class.

Returns:

if checkpoint loaded, return checkpoint dict, else return None

Return type:

dict

log_metrics_wandb(logs, step, train=True)[source]#

Log the metrics to wandb.

Parameters:
  • logs (dict) – Dictionary containing the metrics to log.

  • step (int) – Current step to log. If Trainer.log_train_batch=True, this is the batch iteration, if False (default), this is the epoch.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

model_inference(y, physics, x=None, train=True, **kwargs)[source]#

Perform the model inference.

It returns the network reconstruction given the samples.

Parameters:
  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Optional ground truth. Default is None.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

Returns:

The network reconstruction.

no_learning_inference(y, physics)[source]#

Perform the no learning inference.

By default it returns the (linear) pseudo-inverse reconstruction given the measurement.

Parameters:
  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

Returns:

Reconstructed image.

plot(epoch, physics, x, y, x_net, train=True)[source]#

Plot and optionally save the reconstructions.

Parameters:
  • epoch (int) – Current epoch.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Ground truth.

  • y (torch.Tensor) – Measurement.

  • x_net (torch.Tensor) – Network reconstruction.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

reset_metrics()[source]#

Reset the metrics.

save_best_model(epoch, train_ite, **kwargs)[source]#

Save the best model using validation metrics.

By default, the best model is selected using the first metric in the list. Override this method to provide a custom criterion.

Parameters:
  • epoch (int) – Current epoch.

  • train_ite (int) – Current training batch iteration, equal to (current epoch \(\times\) number of batches) + current batch within epoch

save_model(filename, epoch, state={})[source]#

Save the model.

It saves the model every ckp_interval epochs.

Parameters:
  • filename (str) – Filename of the saved checkpoint.

  • epoch (int) – Current epoch.

  • eval_metrics (None, float) – Evaluation metrics across epochs.

  • state (dict) – Custom objects to save with the model.

setup_train(train=True, **kwargs)[source]#

Set up the training process.

It initializes the wandb logging, the different metrics, the save path, the physics and dataloaders, and the pretrained checkpoint if given.

Parameters:

train (bool) – whether model is being trained.

step(epoch, progress_bar, train_ite=None, train=True, last_batch=False)[source]#

Train/Eval a batch.

It performs the forward pass, the backward pass, and the evaluation at each iteration.

Parameters:
  • epoch (int) – Current epoch.

  • progress_bar (tqdm) – Progress bar.

  • train_ite (int) – Train iteration, only needed for logging if Trainer.log_train_batch=True.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • last_batch (bool) – If True, the last batch of the epoch is being processed.

Returns:

The current physics operator, the ground truth, the measurement, and the network reconstruction.

stop_criterion(epoch, train_ite, **kwargs)[source]#

Stop criterion for early stopping.

By default, stops optimization when first eval metric doesn’t improve in the last 3 evaluations.

Override this method to early stop on a custom condition.

Parameters:
  • epoch (int) – Current epoch.

  • train_ite (int) – Current training batch iteration, equal to (current epoch \(\times\) number of batches) + current batch within epoch

  • metric_history (dict) – Dictionary containing the metrics history, with the metric name as key.

  • metrics (list) – List of metrics used for evaluation.
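
A sketch of a custom criterion (the assumptions about metric_history's structure and the metric-name key are hypothetical; the method is assumed to return True to stop):

    import deepinv as dinv

    class PatientTrainer(dinv.Trainer):
        # Illustrative override: stop only after 10 evaluations without
        # improvement of the first metric, instead of the default 3. Assumes
        # metric_history maps each metric name to its past values (oldest to
        # newest) and that higher is better (e.g. PSNR).
        def stop_criterion(self, epoch, train_ite, metric_history, metrics, **kwargs):
            name = metrics[0].__class__.__name__
            history = metric_history.get(name, [])
            if len(history) <= 10:
                return False
            return max(history[-10:]) <= max(history[:-10])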

test(test_dataloader, save_path=None, compare_no_learning=True, log_raw_metrics=False)[source]#

Test the model, compute metrics and plot images.

Parameters:
  • test_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Test data loader(s), see options 1 to 3 above for how we expect data to be provided.

  • save_path (str) – Directory in which to save the plotted images.

  • compare_no_learning (bool) – If True, the linear reconstruction is compared to the network reconstruction.

  • log_raw_metrics (bool) – if True, also return non-aggregated metrics as a list.

Returns:

dict of metrics results with means and stds.

Return type:

dict
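
For example (test_dataloader is a placeholder):

    # Illustrative usage: evaluate a trained model on a held-out set; the keys
    # of the returned dict depend on the chosen metrics (e.g. PSNR mean and std).
    results = trainer.test(test_dataloader, compare_no_learning=True)
    print(results)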

train()[source]#

Train the model.

It performs the training process, including the setup, the evaluation, the forward and backward passes, and the visualization.

Returns:

The trained model.

Examples using Trainer:#

Imaging inverse problems with adversarial networks

Remote sensing with satellite images

Tour of MRI functionality in DeepInverse

Training a reconstruction network.

Patch priors for limited-angle computed tomography

Self-supervised MRI reconstruction with Artifact2Artifact

Image transformations for Equivariant Imaging

Self-supervised learning with Equivariant Imaging for MRI.

Self-supervised learning from incomplete measurements of multiple operators.

Self-supervised denoising with the Neighbor2Neighbor loss.

Self-supervised denoising with the Generalized R2R loss.

Self-supervised learning with measurement splitting

Self-supervised denoising with the SURE loss.

Self-supervised denoising with the UNSURE loss.

Deep Equilibrium (DEQ) algorithms for image deblurring

Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing

Learned iterative custom prior

Learned Primal-Dual algorithm for CT scan.

Unfolded Chambolle-Pock for constrained image inpainting

Vanilla Unfolded algorithm for super-resolution