test
- deepinv.test(model, test_dataloader, physics, metrics=PSNR(), online_measurements=False, physics_generator=None, device='cpu', plot_images=False, save_folder=None, plot_convergence_metrics=False, verbose=True, rescale_mode='clip', show_progress_bar=True, no_learning_method='A_dagger', **kwargs)
Tests a reconstruction model (algorithm or network).
This function computes the chosen metrics of the reconstruction network on the test set, and optionally plots the reconstructions as well as the metrics computed along the iterations. Note that by default only the first batch is plotted.
- Parameters:
model (torch.nn.Module) – Reconstruction network, which can be a plug-and-play (PnP), unrolled (unfolded), artifact-removal, or any other custom reconstruction network.
test_dataloader (torch.utils.data.DataLoader) – Test data loader, which should yield (x, y) pairs. See datasets for more details.
physics (deepinv.physics.Physics, list[deepinv.physics.Physics]) – Forward operator(s) used by the reconstruction network at test time.
metrics (deepinv.loss.Loss, list[deepinv.loss.Loss]) – Metric or list of metrics used to evaluate the model. See the library's evaluation metrics.
online_measurements (bool) – Generate the measurements online at each iteration by calling physics(x), instead of using the y provided by the data loader.
physics_generator (None, deepinv.physics.generator.PhysicsGenerator) – Optional physics generator. If not None, the physics operators are randomly sampled at each iteration using the generator. Should be used in conjunction with online_measurements=True.
device (torch.device) – Device on which the test is run, e.g. 'cpu' or 'cuda'.
plot_images (bool) – Plot the ground-truth and estimated images.
save_folder (str) – Directory in which to save the plotted reconstructions. Images are saved in the save_folder/images directory.
plot_convergence_metrics (bool) – Plot the metrics with respect to the iterations of the reconstruction algorithm.
verbose (bool) – Print progress information to the console.
plot_measurements (bool) – Plot the measurements y. Default is True.
show_progress_bar (bool) – Show progress bar.
no_learning_method (str) – Reconstruction method used for the no-learning comparison. Options are 'A_dagger', 'A_adjoint', 'prox_l2', or 'y'. Default is 'A_dagger'. The user can change the no-learning method by overriding the deepinv.Trainer.no_learning_inference() method.
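To make the four no-learning options concrete, they can be illustrated for a small dense, real-valued matrix A. This is a minimal NumPy sketch under stated assumptions, not deepinv's implementation: the function name no_learning, the dense-matrix operator, and the prox_l2 choices (starting point z = A^T y and step gamma = 1) are hypothetical.

```python
import numpy as np


def no_learning(A, y, method="A_dagger", gamma=1.0):
    """Hypothetical no-learning baselines for a dense real matrix A (sketch only)."""
    if method == "A_adjoint":
        # Back-projection: apply the adjoint (transpose, for real A) to y.
        return A.T @ y
    if method == "A_dagger":
        # Pseudo-inverse: minimum-norm least-squares solution of Ax = y.
        return np.linalg.pinv(A) @ y
    if method == "prox_l2":
        # argmin_x 0.5*||Ax - y||^2 + ||x - z||^2 / (2*gamma),
        # with the starting point z = A^T y chosen here as an assumption.
        z = A.T @ y
        n = A.shape[1]
        return np.linalg.solve(A.T @ A + np.eye(n) / gamma, A.T @ y + z / gamma)
    if method == "y":
        # Use the measurements directly (only meaningful if y has the shape of x).
        return y
    raise ValueError(f"unknown no-learning method: {method}")
```

The 'y' option is the cheapest baseline and only makes sense for operators such as denoising, where the measurements live in the image domain; the other three use the operator to map y back to that domain.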
- Returns:
A dictionary with the metrics computed on the test set. The keys are the metric names, and the dictionary includes both the average and the standard deviation of each metric.
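The metric computation described above can be sketched as a loop over the test set that optionally regenerates y = physics(x) and then aggregates per-sample scores into a mean and a standard deviation. This is a simplified NumPy sketch, not deepinv's code: evaluate, the callable model(y, physics), the zero-argument physics_generator, and the "PSNR_std" key are hypothetical names chosen for illustration, and plotting is omitted.

```python
import numpy as np


def psnr(x_hat, x, max_val=1.0):
    """Per-sample PSNR in dB for batches shaped (B, ...)."""
    batch = x.shape[0]
    mse = ((x_hat - x) ** 2).reshape(batch, -1).mean(axis=1)
    return 10.0 * np.log10(max_val**2 / mse)


def evaluate(model, dataloader, physics, online_measurements=False, physics_generator=None):
    """Hypothetical stand-in for the metric loop of a test routine."""
    scores = []
    for x, y in dataloader:
        if online_measurements:
            if physics_generator is not None:
                # Draw a fresh operator for this batch (assumed zero-argument callable).
                physics = physics_generator()
            # Regenerate the measurements from the ground truth.
            y = physics(x)
        x_hat = model(y, physics)
        scores.extend(psnr(x_hat, x).tolist())
    # Report the average and the standard deviation over all test samples.
    return {"PSNR": float(np.mean(scores)), "PSNR_std": float(np.std(scores))}
```

For instance, with an identity model and a loader yielding x = 0 and y = 0.5 everywhere, every sample has an MSE of 0.25, so the average PSNR is 10 log10(4) dB and the standard deviation is zero.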
Examples using test:
Training a reconstruction network.
Image deblurring with custom deep explicit prior.
DPIR method for PnP image deblurring.
Regularization by Denoising (RED) for Super-Resolution.
Unfolded Chambolle-Pock for constrained image inpainting