test

class deepinv.test(model, test_dataloader, physics, metrics=PSNR(), online_measurements=False, physics_generator=None, device='cpu', plot_images=False, save_folder=None, plot_convergence_metrics=False, verbose=True, rescale_mode='clip', show_progress_bar=True, no_learning_method='A_dagger', **kwargs)

Tests a reconstruction model (algorithm or network).

This function computes the chosen metrics of the reconstruction network on the test set, and optionally plots the reconstructions as well as the metrics computed along the iterations. Note that by default only the first batch is plotted.
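As a rough illustration of what this evaluation loop does, the following pure-Python sketch reconstructs each measurement, scores the result against the ground truth for every (x, y) pair, and collects the per-image PSNR values. It is not deepinv's implementation: images are assumed to be flat lists of pixel values in [0, 1], the model is simplified to a one-argument callable (deepinv models also receive the physics), and the helpers psnr and evaluate are hypothetical names.

```python
import math


def psnr(x, x_hat, max_pixel=1.0):
    """PSNR in dB between two flat lists of pixel values (hypothetical helper)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    if mse == 0:
        return math.inf  # perfect reconstruction
    return 10 * math.log10(max_pixel ** 2 / mse)


def evaluate(model, test_pairs, metric=psnr):
    """Hypothetical loop: reconstruct each y and score it against its x."""
    scores = []
    for x, y in test_pairs:      # each item is one (ground truth, measurement) pair
        x_hat = model(y)         # simplified: the real models also take the physics
        scores.append(metric(x, x_hat))
    return scores


# Usage: an identity "reconstruction" on two tiny images.
pairs = [([0.0, 0.5, 1.0], [0.1, 0.5, 0.9]), ([0.2, 0.2], [0.2, 0.2])]
print(evaluate(lambda y: y, pairs))
```

The per-image scores are kept in a list so that both an average and a spread can be computed from them afterwards.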

Parameters:
  • model (torch.nn.Module) – Reconstruction network, e.g. a plug-and-play (PnP) method, an unrolled (unfolded) network, an artifact-removal network, or any other custom reconstruction network.

  • test_dataloader (torch.utils.data.DataLoader) – Test data loader, which should yield (x, y) pairs of ground-truth images and measurements. See datasets for more details.

  • physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used by the reconstruction network at test time.

  • metrics (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Metric or list of metrics used to evaluate the model. See the library's evaluation metrics.

  • online_measurements (bool) – Generate the measurements in an online manner at each iteration by calling physics(x).

  • physics_generator (None, deepinv.physics.generator.PhysicsGenerator) – Optional physics generator for generating the physics operators. If not None, the physics operators are randomly sampled at each iteration using the generator. Should be used in conjunction with online_measurements=True.

  • device (torch.device) – Device on which to run the evaluation (e.g. GPU or CPU).

  • plot_images (bool) – Plot the ground-truth and estimated images.

  • save_folder (str) – Directory in which to save plotted reconstructions. Images are saved in the save_folder/images directory.

  • plot_convergence_metrics (bool) – Plot the metrics as a function of the iteration number.

  • verbose (bool) – Print progress information to the console.

  • plot_measurements (bool) – Plot the measurements y. Default is True.

  • show_progress_bar (bool) – Show progress bar.

  • no_learning_method (str) – Reconstruction method used for the no-learning comparison. Options are 'A_dagger', 'A_adjoint', 'prox_l2', or 'y'. Default is 'A_dagger'. The user can modify the no-learning method by overriding the deepinv.Trainer.no_learning_inference() method.
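The interplay between online_measurements, physics_generator and the no-learning comparison can be sketched in plain Python for a denoising problem. All names here (noisy_physics, sample_sigma, evaluate_online) are hypothetical stand-ins rather than deepinv APIs: the "physics" adds Gaussian noise, the "generator" draws a new noise level for each image, and the baseline follows the 'y' option, i.e. the measurement itself is taken as the estimate.

```python
import math
import random


def psnr(x, x_hat, max_pixel=1.0):
    """PSNR in dB between two flat lists of pixel values."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    return math.inf if mse == 0 else 10 * math.log10(max_pixel ** 2 / mse)


def noisy_physics(x, sigma, rng):
    """Hypothetical forward operator: additive Gaussian noise (denoising)."""
    return [xi + rng.gauss(0.0, sigma) for xi in x]


def sample_sigma(rng, low=0.05, high=0.2):
    """Hypothetical physics generator: draws a new noise level per image."""
    return rng.uniform(low, high)


def evaluate_online(model, images, seed=0):
    """Generate measurements on the fly and compare the model with a no-learning baseline."""
    rng = random.Random(seed)          # seeded for reproducible measurements
    model_scores, baseline_scores = [], []
    for x in images:                   # the dataset only supplies ground truth
        sigma = sample_sigma(rng)      # freshly sampled physics parameters
        y = noisy_physics(x, sigma, rng)
        model_scores.append(psnr(x, model(y)))
        baseline_scores.append(psnr(x, y))  # 'y' baseline: x_hat = y
    return model_scores, baseline_scores


# Usage: a "model" that clips the measurement to the valid range [0, 1].
clip = lambda y: [min(max(v, 0.0), 1.0) for v in y]
print(evaluate_online(clip, [[0.0, 0.5, 1.0, 0.25], [0.9, 0.1, 0.4, 0.6]]))
```

With this clipping model, each model score is at least the corresponding baseline score, because clipping can only move a pixel closer to a ground truth that lies in [0, 1].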

Returns:

A dictionary of the metrics computed on the test set, keyed by metric name, containing the average and the standard deviation of each metric.
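The shape of such a dictionary can be illustrated with a small aggregation helper. This is a sketch under assumptions: the helper name summarize, the "_std" key suffix and the use of the population standard deviation are illustrative choices, not necessarily what deepinv.test does internally.

```python
import statistics


def summarize(metric_values):
    """Hypothetical summary of per-image metric values.

    metric_values maps a metric name to its list of per-image values.
    """
    summary = {}
    for name, values in metric_values.items():
        summary[name] = statistics.mean(values)
        # Population standard deviation; the "_std" suffix is an illustrative choice.
        summary[f"{name}_std"] = statistics.pstdev(values)
    return summary


print(summarize({"PSNR": [30.0, 32.0, 34.0]}))
```

For the three PSNR values above, the sketch reports a mean of 32.0 dB and a standard deviation of about 1.633 dB.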

Examples using test:

Training a reconstruction network.

Image deblurring with custom deep explicit prior.

DPIR method for PnP image deblurring.

Regularization by Denoising (RED) for Super-Resolution.

Unfolded Chambolle-Pock for constrained image inpainting