Metrics
This package contains popular metrics for inverse problems.
Metrics are generally used to evaluate the performance of a model, or as the distance function inside a loss function.
Introduction
All metrics inherit from the base class deepinv.loss.metric.Metric, which is a torch.nn.Module.
deepinv.loss.metric.Metric: Base class for metrics.
All metrics take either x_net, x (the reconstruction and the ground truth) for a full-reference metric, or x_net alone for a no-reference metric.
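For instance, a full-reference metric such as PSNR takes both the reconstruction and the ground truth, while a no-reference metric such as NIQE scores the reconstruction alone (a minimal sketch; the random tensors stand in for real images, and NIQE may require an optional dependency to be installed):

>>> import torch
>>> import deepinv as dinv
>>> x = torch.rand(4, 3, 128, 128)  # ground truth batch (B,C,H,W)
>>> x_net = x + 0.05 * torch.randn_like(x)  # simulated reconstruction
>>> psnr = dinv.metric.PSNR()
>>> values = psnr(x_net, x)  # full-reference: one value per image
>>> niqe = dinv.metric.NIQE()
>>> scores = niqe(x_net)  # no-reference: only needs the reconstruction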
All metrics can perform a standard set of pre- and post-processing steps, including operating on complex numbers, normalisation and reduction. See deepinv.loss.metric.Metric for more details.
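For example, these options are set in the metric's constructor (a sketch assuming the complex_abs, norm_inputs and reduction keyword arguments documented on the base class):

>>> import torch
>>> import deepinv as dinv
>>> m = dinv.metric.PSNR(complex_abs=True, norm_inputs="min_max", reduction="mean")
>>> x = torch.rand(4, 2, 16, 16)  # complex data stored as 2 channels (real, imaginary)
>>> value = m(x + 0.01, x)  # magnitude taken, inputs normalised, mean over batch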
Note
By default, metrics do not reduce over the batch dimension, as the usual usage is to average the metric over a dataset yourself. This avoids averaging metrics that have themselves already been averaged over uneven batch sizes. Note that we provide deepinv.utils.AverageMeter to easily keep track of the average of metrics; for example, we use this in our trainer deepinv.training.Trainer. However, you can use the reduction argument to perform a reduction, e.g. if you want a single metric value for a batch rather than an average over a dataset.
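For instance, the per-image values can be collected and averaged over a whole dataset (a minimal sketch; dataloader and model are hypothetical placeholders):

>>> import torch
>>> import deepinv as dinv
>>> psnr = dinv.metric.PSNR()  # default: no reduction, one value per image
>>> values = []
>>> for x, y in dataloader:  # hypothetical loader of (ground truth, measurement) pairs
...     values.append(psnr(model(y), x))  # hypothetical reconstruction model
>>> average = torch.cat(values).mean()  # correct average even with uneven batch sizes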
All metrics can either be used directly as metrics, or as the backbone for training losses. To do this, wrap the metric in a suitable loss such as deepinv.loss.SupLoss or deepinv.loss.MCLoss. In this way, deepinv.loss.metric.MSE replaces torch.nn.MSELoss and deepinv.loss.metric.MAE replaces torch.nn.L1Loss, and you can use these in a loss like SupLoss(metric=MSE()).
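For example, a supervised loss built on MSE acts as a drop-in replacement for torch.nn.MSELoss (a minimal sketch; model and optimiser setup omitted):

>>> import deepinv as dinv
>>> loss_fn = dinv.loss.SupLoss(metric=dinv.metric.MSE())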
Note
For some metrics, higher is better; for these, you must also set train_loss=True so that the metric is inverted and can be minimised as a training loss.
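For instance, SSIM is a higher-is-better metric, so using it inside a loss requires train_loss=True (a minimal sketch, as also shown in the example below):

>>> import deepinv as dinv
>>> loss_fn = dinv.loss.SupLoss(metric=dinv.metric.SSIM(train_loss=True))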
Note
For convenience, you can also import metrics directly from deepinv.metric or deepinv.loss.
Finally, you can also wrap existing metric functions using Metric(metric=f); see deepinv.loss.metric.Metric for an example.
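For instance, any function mapping x_net (and optionally x) to a tensor of per-image values can be wrapped (a minimal sketch with a hand-rolled maximum absolute error):

>>> import torch
>>> import deepinv as dinv
>>> def f(x_net, x):
...     return (x_net - x).abs().amax(dim=(1, 2, 3))  # max absolute error per image
>>> m = dinv.loss.metric.Metric(metric=f)
>>> x = torch.ones(2, 3, 8, 8)
>>> m(x + 0.1, x)
tensor([0.1000, 0.1000])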
Example:
>>> import torch
>>> import deepinv as dinv
>>> m = dinv.metric.SSIM()
>>> x = torch.ones(2, 3, 16, 16) # B,C,H,W
>>> x_hat = x + 0.01
>>> m(x_hat, x) # Calculate metric for each image in batch
tensor([1.0000, 1.0000])
>>> m = dinv.metric.SSIM(reduction="sum")
>>> m(x_hat, x) # Sum over batch
tensor(1.9999)
>>> l = dinv.loss.MCLoss(metric=dinv.metric.SSIM(train_loss=True, reduction="mean")) # Use SSIM for training
Distortion metrics
We implement popular distortion metrics (see The Perception-Distortion Tradeoff for an explanation of distortion vs perceptual metrics):
deepinv.loss.metric.MSE: Mean Squared Error metric.
deepinv.loss.metric.NMSE: Normalised Mean Squared Error metric.
deepinv.loss.metric.MAE: Mean Absolute Error metric.
deepinv.loss.metric.PSNR: Peak Signal-to-Noise Ratio (PSNR) metric.
deepinv.loss.metric.SSIM: Structural Similarity Index (SSIM) metric using torchmetrics.
deepinv.loss.metric.QNR: Quality with No Reference (QNR) metric for pansharpening.
deepinv.loss.metric.L1L2: Combined L2 and L1 metric.
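For example, PSNR depends on the peak signal value, which should be set when your images are not in the [0, 1] range (a sketch assuming the max_pixel argument of deepinv.loss.metric.PSNR):

>>> import torch
>>> import deepinv as dinv
>>> psnr = dinv.metric.PSNR(max_pixel=255)  # assumed argument, for 8-bit-range images
>>> x = 255 * torch.rand(2, 3, 32, 32)
>>> values = psnr(x + 1.0, x)  # one value per image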
Perceptual metrics
We implement popular perceptual metrics:
deepinv.loss.metric.LPIPS: Learned Perceptual Image Patch Similarity (LPIPS) metric.
deepinv.loss.metric.NIQE: Natural Image Quality Evaluator (NIQE) metric.
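For example (a minimal sketch; both metrics wrap pretrained models and, as an assumption here, may require the optional pyiqa dependency to be installed):

>>> import torch
>>> import deepinv as dinv
>>> x = torch.rand(1, 3, 128, 128)
>>> x_net = (x + 0.1 * torch.randn_like(x)).clamp(0, 1)
>>> lpips = dinv.metric.LPIPS()  # full-reference: lower means perceptually closer
>>> d = lpips(x_net, x)
>>> niqe = dinv.metric.NIQE()  # no-reference: scores the reconstruction alone
>>> q = niqe(x_net)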
Utils
A set of popular distances that can be used by the supervised and self-supervised losses.
deepinv.loss.metric.LpNorm: \(\ell_p\) metric for \(p>0\).
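For instance, an \(\ell_1\) distance can be used inside a supervised loss (a sketch assuming LpNorm exposes a p argument, as its description suggests):

>>> import deepinv as dinv
>>> l1 = dinv.loss.metric.LpNorm(p=1)  # assumed constructor argument
>>> loss_fn = dinv.loss.SupLoss(metric=l1)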