Metrics

This package contains popular metrics for inverse problems.

Metrics are generally used to evaluate the performance of a model, or as the distance function inside a loss function.

Introduction

All metrics inherit from the base class deepinv.loss.metric.Metric, which is a torch.nn.Module.

deepinv.loss.metric.Metric

Base class for metrics.

All metrics take x_net (the model output) and x (the ground truth) for a full-reference metric, or only x_net for a no-reference metric.

All metrics can perform a standard set of pre- and post-processing, including handling complex inputs, input normalisation, and reduction. See deepinv.loss.metric.Metric for more details.
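For instance, a minimal sketch of these options; the argument names complex_abs, norm_inputs and reduction are assumptions about what Metric-derived classes forward to the base class, so check deepinv.loss.metric.Metric for the exact interface:

>>> import torch
>>> import deepinv as dinv
>>> # Take the magnitude of complex inputs before computing the metric (assumed option name)
>>> m = dinv.metric.PSNR(complex_abs=True)
>>> # Normalise both inputs before comparison (assumed option name and value)
>>> m = dinv.metric.PSNR(norm_inputs="min_max")
>>> # Reduce the per-image values to a single mean over the batch
>>> m = dinv.metric.PSNR(reduction="mean")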

Note

By default, metrics do not reduce over the batch dimension, as the usual usage is to average the metric over a dataset yourself. This avoids averaging values that have themselves already been averaged over uneven batch sizes. We provide deepinv.utils.AverageMeter to easily keep track of the running average of a metric; for example, we use it in our trainer deepinv.training.Trainer.

However, you can use the reduction argument to reduce over the batch dimension, e.g. if you want a single value for a batch rather than an average over a dataset.
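For example, a sketch of tracking the running average of an unreduced metric over batches; this assumes AverageMeter exposes the common update(value, n)/avg interface:

>>> import torch
>>> import deepinv as dinv
>>> m = dinv.metric.PSNR()  # no reduction: one value per image
>>> meter = dinv.utils.AverageMeter("PSNR")
>>> x = torch.rand(2, 3, 16, 16)  # stand-in for one batch from a dataset
>>> x_net = x + 0.01 * torch.randn_like(x)
>>> vals = m(x_net, x)
>>> meter.update(vals.mean().item(), n=len(vals))  # weight by batch size (assumed API)
>>> avg_psnr = meter.avg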

All metrics can either be used directly as metrics, or as the backbone for training losses. To do this, wrap the metric in a suitable loss such as deepinv.loss.SupLoss or deepinv.loss.MCLoss. In this way, deepinv.loss.metric.MSE replaces torch.nn.MSELoss and deepinv.loss.metric.MAE replaces torch.nn.L1Loss, and you can use these in a loss like SupLoss(metric=MSE()).
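For instance, a minimal sketch of this substitution, assuming SupLoss can be called with x_net and x keyword arguments as in the Trainer loop:

>>> import torch
>>> import deepinv as dinv
>>> loss = dinv.loss.SupLoss(metric=dinv.metric.MSE())  # drop-in for torch.nn.MSELoss
>>> x = torch.rand(2, 3, 16, 16)
>>> x_net = x + 0.01 * torch.randn_like(x)
>>> l = loss(x_net=x_net, x=x)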

Note

For metrics where higher is better, you must also set train_loss=True so that the metric can be used as a training loss.

Note

For convenience, you can also import metrics directly from deepinv.metric or deepinv.loss.

Finally, you can also wrap existing metric functions using Metric(metric=f); see deepinv.loss.metric.Metric for an example, and the sketch below.
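For instance, a sketch wrapping a hypothetical per-image function f; the expected (x_net, x) signature returning one value per batch element is an assumption consistent with the conventions above:

>>> import torch
>>> import deepinv as dinv
>>> f = lambda x_net, x: (x_net - x).abs().mean(dim=(1, 2, 3))  # hypothetical per-image score
>>> m = dinv.loss.metric.Metric(metric=f)
>>> x = torch.rand(2, 3, 16, 16)
>>> vals = m(x + 0.01, x)  # one value per image in the batch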

Example:

>>> import torch
>>> import deepinv as dinv
>>> m = dinv.metric.SSIM()
>>> x = torch.ones(2, 3, 16, 16) # B,C,H,W
>>> x_hat = x + 0.01
>>> m(x_hat, x) # Calculate metric for each image in batch
tensor([1.0000, 1.0000])
>>> m = dinv.metric.SSIM(reduction="sum")
>>> m(x_hat, x) # Sum over batch
tensor(1.9999)
>>> l = dinv.loss.MCLoss(metric=dinv.metric.SSIM(train_loss=True, reduction="mean")) # Use SSIM for training

Distortion metrics

We implement popular distortion metrics (see The Perception-Distortion Tradeoff for an explanation of distortion vs. perceptual metrics); a usage sketch follows the list:

deepinv.loss.metric.MSE

Mean Squared Error metric.

deepinv.loss.metric.NMSE

Normalised Mean Squared Error metric.

deepinv.loss.metric.MAE

Mean Absolute Error metric.

deepinv.loss.metric.PSNR

Peak Signal-to-Noise Ratio (PSNR) metric.

deepinv.loss.metric.SSIM

Structural Similarity Index (SSIM) metric using torchmetrics.

deepinv.loss.metric.QNR

Quality with No Reference (QNR) metric for pansharpening.

deepinv.loss.metric.L1L2

Combined L2 and L1 metric.
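All of the above share the calling convention shown in the introduction; for example, a sketch comparing a few of them on the same pair of images:

>>> import torch
>>> import deepinv as dinv
>>> x = torch.rand(2, 3, 16, 16)
>>> x_net = x + 0.01 * torch.randn_like(x)
>>> mse = dinv.metric.MSE()(x_net, x)    # lower is better
>>> mae = dinv.metric.MAE()(x_net, x)    # lower is better
>>> psnr = dinv.metric.PSNR()(x_net, x)  # higher is better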

Perceptual metrics

We implement popular perceptual metrics; a usage sketch follows the list:

deepinv.loss.metric.LPIPS

Learned Perceptual Image Patch Similarity (LPIPS) metric.

deepinv.loss.metric.NIQE

Natural Image Quality Evaluator (NIQE) metric.
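Both follow the same interface; a sketch, assuming the optional backends these metrics wrap (e.g. pyiqa) are installed:

>>> import torch
>>> import deepinv as dinv
>>> x = torch.rand(2, 3, 128, 128)
>>> x_net = x + 0.01 * torch.randn_like(x)
>>> lpips = dinv.metric.LPIPS()(x_net, x)  # full-reference: lower is better
>>> niqe = dinv.metric.NIQE()(x_net)       # no-reference: only x_net is needed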

Utils

A set of popular distances that can be used by the supervised and self-supervised losses.

deepinv.loss.metric.LpNorm

\(\ell_p\) metric for \(p>0\).
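For example, a minimal sketch, assuming p is the constructor argument of LpNorm:

>>> import torch
>>> import deepinv as dinv
>>> m = dinv.loss.metric.LpNorm(p=1.5)  # assumed constructor argument
>>> x = torch.rand(2, 3, 16, 16)
>>> d = m(x + 0.01, x)  # one distance per image in the batch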