Metric
- class deepinv.loss.metric.Metric(metric: Callable | None = None, complex_abs: bool = False, train_loss: bool = False, reduction: str | None = None, norm_inputs: str | None = None)
Bases: Module
Base class for metrics.
See the docs for forward() below for more details.
To create a new metric, inherit from this class, override the function deepinv.metric.Metric.metric(), set the lower_better attribute, and optionally override the invert_metric method.
You can also use this base class directly to wrap an existing metric function, e.g. from torchmetrics, to benefit from our preprocessing. The metric function must reduce over all dims except the batch dim (see example).
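The requirement that a wrapped function reduce over every dimension except the batch dimension can be illustrated with a minimal sketch in plain PyTorch. The function name per_sample_mse is hypothetical and is not part of the library; it only shows the expected input and output shapes.

```python
import torch


def per_sample_mse(x_net: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical metric: mean squared error reduced over all dims except batch."""
    # Flatten every non-batch dimension so the mean is taken per sample.
    diff = (x_net - x).flatten(start_dim=1)
    return diff.pow(2).mean(dim=1)  # shape (B,), the batch dim is not reduced


x = torch.ones(2, 3, 8, 8)         # B, C, H, W
x_net = x.clone()
x_net[1] += 0.5                    # perturb only the second sample
scores = per_sample_mse(x_net, x)  # one score per batch element
```

A function with this shape contract could then be passed as the metric argument of the constructor, for example Metric(metric=per_sample_mse).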
- Parameters:
metric (Callable) – metric function. It must reduce over all dims except the batch dim, and must not reduce over the batch dim. This is unused if the metric method is overridden.
complex_abs (bool) – compute the complex magnitude before passing data to the metric function. If True, the data must either be of complex dtype or have size 2 in the channel dimension (usually the second dimension, after batch).
train_loss (bool) – use the metric as a training loss: if higher is better, the metric is inverted; if lower is better, it is left unchanged.
reduction (str) – a method to reduce the metric score over individual batch scores. mean: takes the mean; sum: takes the sum; none or None: no reduction is applied (default).
norm_inputs (str) – normalize images before passing them to the metric. l2 normalizes by the L2 spatial norm; min_max normalizes by the min and max of each input.
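The preprocessing and reduction options described above can be sketched in plain PyTorch as follows. This is an illustration of the documented behaviour, not the library's implementation: the helper names are hypothetical, and details such as keeping a singleton channel dimension after the magnitude, normalizing per sample, and the small eps are assumptions.

```python
import torch


def complex_magnitude(x: torch.Tensor) -> torch.Tensor:
    # Complex dtype: take the modulus directly.
    if torch.is_complex(x):
        return x.abs()
    # Otherwise expect real and imaginary parts stacked in the channel dim (dim 1).
    assert x.shape[1] == 2, "expected size 2 in the channel dimension"
    # Assumption: keep a singleton channel dim after taking the magnitude.
    return x.pow(2).sum(dim=1, keepdim=True).sqrt()


def min_max_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Assumption: min and max are computed per sample over all non-batch dims.
    flat = x.flatten(start_dim=1)
    lo = flat.min(dim=1).values
    hi = flat.max(dim=1).values
    shape = (-1,) + (1,) * (x.dim() - 1)  # broadcast back over non-batch dims
    return (x - lo.view(shape)) / (hi - lo + eps).view(shape)


def reduce_scores(scores: torch.Tensor, reduction=None) -> torch.Tensor:
    # Reduce per-sample scores over the batch dimension.
    if reduction == "mean":
        return scores.mean()
    if reduction == "sum":
        return scores.sum()
    if reduction in (None, "none"):
        return scores
    raise ValueError(f"unknown reduction: {reduction}")


z = torch.zeros(1, 2, 1, 1)
z[0, 0], z[0, 1] = 3.0, 4.0          # real part 3, imaginary part 4
mag = complex_magnitude(z)

imgs = torch.tensor([[0.0, 2.0, 4.0], [10.0, 20.0, 30.0]])
normed = min_max_normalize(imgs)

batch_scores = torch.tensor([1.0, 3.0])
mean_score = reduce_scores(batch_scores, "mean")
sum_score = reduce_scores(batch_scores, "sum")
raw_scores = reduce_scores(batch_scores)
```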
Examples
Use Metric to wrap functional metrics such as from torchmetrics:
>>> from functools import partial
>>> from torch import ones
>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> from deepinv.loss.metric import Metric
>>> m = Metric(metric=partial(structural_similarity_index_measure, reduction='none'))
>>> x = x_net = ones(2, 3, 64, 64) # B,C,H,W
>>> m(x_net - 0.1, x)
tensor([0., 0.])
- forward(x_net: Tensor | None = None, x: Tensor | None = None, *args, **kwargs) → Tensor
Metric forward pass.
Usually, the data passed is x_net, x, i.e. estimate and target, or only x_net for a no-reference metric.
The forward pass also optionally calculates the complex magnitude of images, performs normalization, or inverts the metric to use it as a training loss (if by default higher is better).
By default, no reduction is performed in the batch dimension, but mean or sum reduction can be performed too.
All tensors should be of shape (B, ...) or (B, C, ...), where B is batch size and C is channels.
- Parameters:
x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape (B, ...) or (B, C, ...).
x (torch.Tensor) – Reference image \(x\) (optional) of shape (B, ...) or (B, C, ...).
- Return torch.Tensor:
calculated metric; the tensor size might be (1,) or (B,).
- invert_metric(m: Tensor)
Invert metric. Used where a higher=better metric is to be used in a training loss.
- Parameters:
m (Tensor) – calculated metric
- metric(x_net: Tensor | None = None, x: Tensor | None = None, *args, **kwargs) → Tensor
Calculate metric on data.
Override this function to implement your own metric. Always include the args and kwargs arguments.
- Parameters:
x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape (B, ...) or (B, C, ...).
x (torch.Tensor) – Reference image \(x\) (optional) of shape (B, ...) or (B, C, ...).
- Return torch.Tensor:
calculated metric; the tensor size might be (1,) or (B,).
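As a rough sketch of what the body of an overridden metric might look like, the standalone function below follows the same argument convention (x_net, optional x, then args and kwargs) and returns one score per batch element. The choice of anisotropic total variation as a no-reference score is purely illustrative and is an assumption, not something this class provides; in a subclass the same logic would sit inside the metric method, with self as the first argument.

```python
import torch


def total_variation(x_net=None, x=None, *args, **kwargs) -> torch.Tensor:
    """Hypothetical no-reference score: anisotropic total variation per sample.

    Expects x_net of shape (B, C, H, W); the reference x is accepted but ignored.
    """
    # Absolute differences between vertically adjacent pixels, summed per sample.
    dh = (x_net[..., 1:, :] - x_net[..., :-1, :]).abs().flatten(1).sum(dim=1)
    # Absolute differences between horizontally adjacent pixels, summed per sample.
    dw = (x_net[..., :, 1:] - x_net[..., :, :-1]).abs().flatten(1).sum(dim=1)
    return dh + dw  # shape (B,)


flat = torch.ones(2, 1, 4, 4)                  # constant images
ramp = torch.arange(3.0).expand(1, 1, 2, 3)    # each row is [0, 1, 2]
tv_flat = total_variation(flat)
tv_ramp = total_variation(ramp)
```

Accepting args and kwargs even when they are unused keeps the function callable with any extra arguments passed through from the forward call.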