Trainer

class deepinv.Trainer(model, physics, optimizer, train_dataloader, ...)[source]

Bases: object

Trainer class for training a reconstruction network.

Training can be done by calling the deepinv.Trainer.train() method, whereas testing can be done by calling the deepinv.Trainer.test() method.

Training details are saved every ckp_interval epochs in the following format:

save_path/yyyy-mm-dd_hh-mm-ss/ckp_{epoch}.pth.tar

where the .pth.tar file contains a dictionary with the keys: epoch (the current epoch), state_dict (the state dictionary of the model), loss (the loss history), optimizer (the state dictionary of the optimizer), and eval_metrics (the evaluation metrics history).
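As a rough illustration of the naming scheme above, the following standard-library sketch builds a checkpoint path of that shape. The helper name checkpoint_path and the example values are hypothetical, and the timestamp format string is an assumption inferred from the yyyy-mm-dd_hh-mm-ss pattern; this is not the library's own code.

```python
from datetime import datetime
from pathlib import Path


def checkpoint_path(save_path, started_at, epoch):
    """Hypothetical helper: build save_path/yyyy-mm-dd_hh-mm-ss/ckp_{epoch}.pth.tar."""
    # Assumed timestamp format, inferred from the documented folder pattern
    run_folder = started_at.strftime("%Y-%m-%d_%H-%M-%S")
    return Path(save_path) / run_folder / f"ckp_{epoch}.pth.tar"


path = checkpoint_path("ckpts", datetime(2024, 5, 17, 9, 30, 0), epoch=10)
print(path.as_posix())  # ckpts/2024-05-17_09-30-00/ckp_10.pth.tar
```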

The class provides a flexible training loop that can be customized by the user. In particular, the user can rewrite the deepinv.Trainer.compute_loss() method to define their custom training step without having to write all the training code from scratch:

from deepinv import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, physics, x, y, train=True, epoch: int = None):
        logs = {}

        self.optimizer.zero_grad()  # Zero the gradients

        # Evaluate reconstruction network
        x_net = self.model_inference(y=y, physics=physics)

        # Compute the losses and sum their batch means
        loss_total = 0
        for loss_fn in self.losses:
            loss = loss_fn(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
            loss_total += loss.mean()

        # Update the running average of the total loss (train or eval)
        metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
        metric.update(loss_total.item())
        logs["TotalLoss"] = metric.avg

        if train:
            loss_total.backward()  # Backpropagate the total loss
            self.optimizer.step()  # Optimizer step

        return x_net, logs

If the user wants to change the way the metrics are computed, they can rewrite the deepinv.Trainer.compute_metrics() method. The user can also change the way samples are generated by overriding either the deepinv.Trainer.get_samples_online() method or the deepinv.Trainer.get_samples_offline() method, e.g. to change the physics parameters on-the-fly with parameters from the dataset.

For instance, in MRI, the dataloader often returns both the measurements and the mask associated with the measurements. In this case, to update the deepinv.physics.Physics() parameters accordingly, a potential implementation would be:

from deepinv import Trainer


class CustomTrainer(Trainer):
    def get_samples_offline(self, iterators, g):
        # Suppose your dataset returns per-sample masks, e.g. in MRI
        x, y, mask = next(iterators[g])

        # Suppose physics has class params such as DecomposablePhysics or MRI
        physics = self.physics[g]

        # Update physics parameters deterministically (i.e. not using a random generator)
        physics.update_parameters(mask=mask.to(self.device))

        return x.to(self.device), y.to(self.device), physics

Note

The training code can synchronize with Weights & Biases for logging and visualization by setting wandb_vis=True. The user can also customize the wandb setup by providing a dictionary with the setup for wandb.

Note

The losses and evaluation metrics can be chosen from the library's training losses, or can be custom loss functions, as long as each takes as input (x, x_net, y, physics, model) and returns a scalar, where x is the ground truth, x_net is the network reconstruction \(\inversef{y}{A}\), y is the measurement vector, physics is the forward operator and model is the reconstruction network. Not all inputs need to be used by the loss; e.g., self-supervised losses do not make use of x.
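A minimal sketch of a custom loss with this calling convention might look as follows. It assumes the loss is called with keyword arguments (as in the compute_loss example above, which also passes epoch), so unused inputs are absorbed by **kwargs. The names measurement_consistency_loss and toy_physics are hypothetical stand-ins, and whether a plain function is accepted without further wrapping is not confirmed here.

```python
import torch


def measurement_consistency_loss(x_net, y, physics, **kwargs):
    """Self-supervised loss: mean squared error between A(x_net) and y.

    x, model and any extra keywords (e.g. epoch) arrive via **kwargs and are ignored.
    """
    return ((physics(x_net) - y) ** 2).mean()


def toy_physics(x):
    # Stand-in forward operator (assumption): keep every other pixel on the last axis
    return x[..., ::2]


x = torch.ones(2, 1, 4, 4)                            # ground truth, unused by this loss
y = toy_physics(x)                                    # noiseless measurements
x_net = torch.zeros(2, 1, 4, 4, requires_grad=True)   # pretend network output

# Keyword-style call, mirroring the loop over self.losses
loss = measurement_consistency_loss(
    x=x, x_net=x_net, y=y, physics=toy_physics, model=None, epoch=0
)
loss.backward()  # the scalar result supports backpropagation
```

Accepting **kwargs keeps the loss compatible if the caller passes additional keywords later on.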

Warning

If a physics generator is used to generate params for online measurements, the generated params will vary each epoch. If this is not desired (you want the same online measurements each epoch), set loop_physics_generator=True. Caveat: this requires shuffle=False in your dataloaders. An alternative solution is to generate and save params offline using deepinv.datasets.generate_dataset().
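The effect of resetting the generator each epoch can be sketched with the standard library alone. ToyParamGenerator and run_epochs are hypothetical stand-ins rather than deepinv classes; the sketch only illustrates why a reset combined with a fixed sample order reproduces the same (sample, parameter) pairs every epoch.

```python
import random


class ToyParamGenerator:
    """Hypothetical stand-in for a physics generator: draws one noise level per step."""

    def __init__(self, seed=0):
        self.seed = seed
        self.rng = random.Random(seed)

    def reset(self):
        # Return to the initial state so the same sequence is drawn again
        self.rng = random.Random(self.seed)

    def step(self):
        return round(self.rng.uniform(0.01, 0.1), 4)


def run_epochs(generator, n_epochs, n_samples, loop):
    """Return, per epoch, the (sample index, parameter) pairs that would be used."""
    epochs = []
    for _ in range(n_epochs):
        if loop:
            generator.reset()  # mimics loop_physics_generator=True
        # Fixed sample order, as with shuffle=False
        epochs.append([(i, generator.step()) for i in range(n_samples)])
    return epochs


looped = run_epochs(ToyParamGenerator(), n_epochs=2, n_samples=3, loop=True)
fresh = run_epochs(ToyParamGenerator(), n_epochs=2, n_samples=3, loop=False)

print(looped[0] == looped[1])  # True: identical pairs every epoch
print(fresh[0] == fresh[1])    # False: new parameters each epoch
```

If the sample order changed between epochs, the i-th generated parameter would be paired with a different signal each time, which is why shuffle=False is required.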

Parameters:
  • model (torch.nn.Module) – Reconstruction network, which can be PnP, unrolled, artifact removal or any other custom reconstruction network.

  • physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used by the reconstruction network.

  • epochs (int) – Number of training epochs. Default is 100.

  • optimizer (torch.optim.Optimizer) – Torch optimizer for training the network.

  • train_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Train data loader(s) should provide a signal x or a tuple of (x, y) signal/measurement pairs.

  • losses (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Loss or list of losses used for training the model. Optionally wrap losses using a loss scheduler for more advanced training. See the library's training losses. By default, it uses the supervised mean squared error.

  • eval_dataloader (None, torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Evaluation data loader(s) should provide a signal x or a tuple of (x, y) signal/measurement pairs.

  • scheduler (None, torch.optim.lr_scheduler) – Torch scheduler for changing the learning rate across iterations.

  • online_measurements (bool) – Generate the measurements in an online manner at each iteration by calling physics(x). This results in a wider range of measurements if the physics’ parameters, such as parameters of the forward operator or noise realizations, can change between each sample. If False, the measurements are loaded from the training dataset.

  • physics_generator (None, deepinv.physics.generator.PhysicsGenerator) – Optional physics generator for generating the physics operators. If not None, the physics operators are randomly sampled at each iteration using the generator. Should be used in conjunction with online_measurements=True. Also see loop_physics_generator.

  • metrics (Metric, list[Metric]) – Metric or list of metrics used for evaluating the model. See the library's evaluation metrics.

  • grad_clip (float) – Gradient clipping value for the optimizer. If None, no gradient clipping is performed.

  • ckp_interval (int) – The model is saved every ckp_interval epochs.

  • eval_interval (int) – Number of epochs between each evaluation of the model on the evaluation set.

  • save_path (str) – Directory in which to save the trained model.

  • device (str) – Device on which to run the training (e.g., ‘cuda’ or ‘cpu’).

  • verbose (bool) – Output training progress information in the console.

  • show_progress_bar (bool) – Show a progress bar during training.

  • plot_images (bool) – Plots reconstructions every ckp_interval epochs.

  • wandb_vis (bool) – Logs data onto Weights & Biases, see https://wandb.ai/ for more details.

  • wandb_setup (dict) – Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more details.

  • plot_measurements (bool) – Plot the measurements y. Default is True.

  • check_grad (bool) – Compute and print the gradient norm at each iteration.

  • ckpt_pretrained (str) – Path of the pretrained checkpoint. If None, no pretrained checkpoint is loaded.

  • freq_plot (int) – Frequency of plotting images to wandb during train evaluation (at the end of each epoch). If 1, plots at each epoch.

  • verbose_individual_losses (bool) – If True, the values of the individual losses are printed during training. Otherwise, only the total loss is printed.

  • display_losses_eval (bool) – If True, the losses are displayed during evaluation.

  • rescale_mode (str) – Rescale mode for plotting images. Default is 'clip'.

  • compare_no_learning (bool) – If True, the no learning method is compared to the network reconstruction.

  • no_learning_method (str) – Reconstruction method used for the no learning comparison. Options are 'A_dagger', 'A_adjoint', 'prox_l2', or 'y'. Default is 'A_dagger'. The user can also provide a custom method by overriding the deepinv.Trainer.no_learning_inference() method.

  • loop_physics_generator (bool) – If True, resets the physics generator back to its initial state at the beginning of each epoch, so that the same measurements are generated each epoch. Requires shuffle=False in dataloaders. If False, generates new physics every epoch. Used in conjunction with physics_generator.

check_clip_grad()[source]

Check the gradient norm and perform gradient clipping if necessary.
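One common way to check the gradient norm and clip it is PyTorch's torch.nn.utils.clip_grad_norm_. The sketch below assumes norm-based clipping with grad_clip as the maximum norm; the library's own method may differ in detail, and the toy model and loss are illustrative only.

```python
import torch

model = torch.nn.Linear(4, 1)
loss = (model(torch.ones(8, 4)) * 100.0).sum()  # deliberately large gradients
loss.backward()

grad_clip = 1.0  # plays the role of the grad_clip parameter

# clip_grad_norm_ rescales the gradients in place and returns the norm before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)

# Recompute the total gradient norm after clipping
norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))
print(f"before: {float(norm_before):.1f}, after: {float(norm_after):.3f}")
```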

compute_loss(physics, x, y, train=True, epoch: int | None = None)[source]

Compute the loss and perform the backward pass.

It evaluates the reconstruction network, computes the losses, and performs the backward pass.

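Parameters:
  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Ground truth.

  • y (torch.Tensor) – Measurement.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • epoch (int) – Current epoch.
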
Returns:

(tuple) The network reconstruction x_net (for plotting and computing metrics) and the logs (for printing the training progress).

compute_metrics(x, x_net, y, physics, logs, train=True, epoch: int | None = None)[source]

Compute the metrics.

It computes the metrics over the batch.

Parameters:
  • x (torch.Tensor) – Ground truth.

  • x_net (torch.Tensor) – Network reconstruction.

  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • logs (dict) – Dictionary containing the logs for printing the training progress.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • epoch (int) – Current epoch.

Returns:

The logs with the metrics.

get_samples(iterators, g)[source]

Get the samples.

This function returns a tuple containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data.

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

The tuple returned by the get_samples_online or get_samples_offline function.

get_samples_offline(iterators, g)[source]

Get the samples for the offline measurements.

In this setting, samples have been generated offline and are loaded from the dataloader. This function returns a tuple containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data (you can override this function to add custom data).

If the dataloader returns 3-tuples, this is assumed to be (x, y, params) where params is a dict of physics generator params. These params are then used to update the physics.
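A dataset producing such 3-tuples could be sketched as below. MaskedToyDataset and the "mask" key are hypothetical; the key is assumed to match a parameter name that the physics accepts, as in the MRI mask example earlier on this page.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MaskedToyDataset(Dataset):
    """Hypothetical dataset yielding (x, y, params) with a per-sample mask."""

    def __init__(self, n_samples=8, size=4, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.x = torch.rand(n_samples, 1, size, size, generator=g)
        self.masks = (torch.rand(n_samples, 1, size, size, generator=g) > 0.5).float()

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, i):
        mask = self.masks[i]
        y = mask * self.x[i]  # toy masked measurement
        return self.x[i], y, {"mask": mask}


loader = DataLoader(MaskedToyDataset(), batch_size=4, shuffle=False)
x, y, params = next(iter(loader))

# The default collate function batches the dict values as well
print(x.shape, y.shape, params["mask"].shape)
```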

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

A tuple containing at least: the ground truth, the measurement, and the current physics operator.

get_samples_online(iterators, g)[source]

Get the samples for the online measurements.

In this setting, a new sample is generated at each iteration by calling the physics operator. This function returns a tuple containing necessary data for the model inference. It needs to contain the measurement, the ground truth, and the current physics operator, but can also contain additional data.

Parameters:
  • iterators (list) – List of dataloader iterators.

  • g (int) – Current dataloader index.

Returns:

A tuple containing at least: the ground truth, the measurement, and the current physics operator.

log_metrics_wandb(logs, epoch, train=True)[source]

Log the metrics to wandb.

It logs the metrics to wandb.

Parameters:
  • logs (dict) – Dictionary containing the metrics to log.

  • epoch (int) – Current epoch.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

model_inference(y, physics, x=None, train=True, **kwargs)[source]

Perform the model inference.

It returns the network reconstruction given the samples.

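Parameters:
  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Optional ground truth.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.
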
Returns:

The network reconstruction.

no_learning_inference(y, physics)[source]

Perform the no learning inference.

By default it returns the (linear) pseudo-inverse reconstruction given the measurement.

Parameters:
  • y (torch.Tensor) – Measurement.

  • physics (deepinv.physics.Physics) – Current physics operator.

Returns:

Reconstructed image.
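For intuition, the difference between a pseudo-inverse and an adjoint reconstruction can be sketched on a small dense matrix. This is a toy illustration with an explicit matrix A rather than a deepinv physics object, and it assumes that the 'A_dagger' and 'A_adjoint' options correspond to these two operations.

```python
import torch

torch.manual_seed(0)

A = torch.randn(6, 4, dtype=torch.float64)  # toy linear forward operator, y = A x
x = torch.randn(4, dtype=torch.float64)     # ground truth
y = A @ x                                   # noiseless measurement

x_dagger = torch.linalg.pinv(A) @ y   # pseudo-inverse reconstruction
x_adjoint = A.T @ y                   # adjoint (back-projection) reconstruction

# With full column rank and no noise, the pseudo-inverse recovers x
print(torch.allclose(x_dagger, x))
# The adjoint alone generally does not, since A^T A is not the identity
print(torch.allclose(x_adjoint, x))
```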

plot(epoch, physics, x, y, x_net, train=True)[source]

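Plot and optionally save the reconstructions.

Parameters:
  • epoch (int) – Current epoch.

  • physics (deepinv.physics.Physics) – Current physics operator.

  • x (torch.Tensor) – Ground truth.

  • y (torch.Tensor) – Measurement.

  • x_net (torch.Tensor) – Network reconstruction.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.
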
reset_metrics()[source]

Reset the metrics.

save_model(epoch, eval_metrics=None, state={})[source]

Save the model.

It saves the model every ckp_interval epochs.

Parameters:
  • epoch (int) – Current epoch.

  • eval_metrics (None, float) – Evaluation metrics across epochs.

  • state (dict) – Custom objects to save with the model.

setup_train(train=True, **kwargs)[source]

Set up the training process.

It initializes the wandb logging, the different metrics, the save path, the physics and dataloaders, and the pretrained checkpoint if given.

step(epoch, progress_bar, train=True, last_batch=False)[source]

Train/Eval a batch.

It performs the forward pass, the backward pass, and the evaluation at each iteration.

Parameters:
  • epoch (int) – Current epoch.

  • progress_bar (tqdm) – Progress bar.

  • train (bool) – If True, the model is trained, otherwise it is evaluated.

  • last_batch (bool) – If True, the last batch of the epoch is being processed.

Returns:

The current physics operator, the ground truth, the measurement, and the network reconstruction.

test(test_dataloader, save_path=None, compare_no_learning=True)[source]

Test the model.

Parameters:
  • test_dataloader (torch.utils.data.DataLoader, list[torch.utils.data.DataLoader]) – Test data loader(s) should provide a signal x or a tuple of (x, y) signal/measurement pairs.

  • save_path (str) – Directory in which to save the trained model.

  • compare_no_learning (bool) – If True, the linear reconstruction is compared to the network reconstruction.

Returns:

The trained model.

train()[source]

Train the model.

It performs the training process, including the setup, the evaluation, the forward and backward passes, and the visualization.

Returns:

The trained model.

Examples using Trainer:

Imaging inverse problems with adversarial networks

Training a reconstruction network.

Patch priors for limited-angle computed tomography

Self-supervised learning with measurement splitting

Image transformations for Equivariant Imaging

Self-supervised MRI reconstruction with Artifact2Artifact

Self-supervised denoising with the UNSURE loss.

Self-supervised denoising with the SURE loss.

Self-supervised denoising with the Neighbor2Neighbor loss.

Self-supervised learning with Equivariant Imaging for MRI.

Self-supervised learning from incomplete measurements of multiple operators.

Vanilla Unfolded algorithm for super-resolution

Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing

Deep Equilibrium (DEQ) algorithms for image deblurring

Learned iterative custom prior

Learned Primal-Dual algorithm for CT scan.

Unfolded Chambolle-Pock for constrained image inpainting