Metric#
- class deepinv.loss.metric.Metric(metric=None, complex_abs=False, train_loss=False, reduction=None, norm_inputs=None)[source]#
Bases: `Module`

Base class for metrics.

See docs for `forward` below for more details.

To create a new metric, inherit from this class, override the `metric` method, set the `lower_better` attribute and optionally override the `invert_metric` method.

You can also directly use this base class to wrap an existing metric function, e.g. from torchmetrics, to benefit from our preprocessing. The metric function must reduce over all dims except the batch dim (see example).
- Parameters:
  - metric (Callable) – metric function; it must reduce over all dims except the batch dim, and must not reduce over the batch dim. Unused if the `metric` method is overridden. Takes `x_net` and `x` tensors as input and returns a tensor of metric scores.
  - complex_abs (bool) – compute the complex magnitude before passing data to the metric function. If `True`, the data must either be of complex dtype or have size 2 in the channel dimension (usually the second dimension after batch).
  - train_loss (bool) – if higher is better, invert the metric; if lower is better, do nothing.
  - reduction (str) – method to reduce the metric score over individual batch scores. `mean` takes the mean, `sum` takes the sum, `none` or `None` applies no reduction (default).
  - norm_inputs (str) – normalize images before passing them to the metric. `l2` normalizes by the L2 spatial norm, `min_max` normalizes by the min and max of each input, `clip` clips to \([0,1]\), `standardize` standardizes to the same mean and std as the ground truth, `none` or `None` applies no normalization (default).
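The `reduction` and `norm_inputs` options described above can be read as the following plain-Python sketch. This illustrates the documented semantics only; it is not deepinv's implementation, which operates on torch tensors:

```python
# Illustrative sketch of the documented reduction and min_max normalization
# semantics; plain lists stand in for tensors.

def reduce_scores(scores, reduction=None):
    """Reduce per-sample metric scores over the batch dimension."""
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction in (None, "none"):
        return scores  # default: no reduction
    raise ValueError(f"unknown reduction: {reduction!r}")

def min_max_normalize(x):
    """Scale one input's values to [0, 1] by its own min and max."""
    lo, hi = min(x), max(x)
    return [(v - lo) / (hi - lo) for v in x]

print(reduce_scores([1.0, 3.0], "mean"))   # 2.0
print(min_max_normalize([2.0, 4.0, 6.0]))  # [0.0, 0.5, 1.0]
```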
Examples

Use `Metric` to wrap functional metrics such as from torchmetrics:

>>> import torch
>>> from functools import partial
>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> from deepinv.loss.metric import Metric
>>> m = Metric(metric=partial(structural_similarity_index_measure, reduction='none'))
>>> x = x_net = torch.ones(2, 3, 64, 64)  # B,C,H,W
>>> m(x_net - 0.1, x)
tensor([0., 0.])
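The subclassing route described above (override `metric`, set `lower_better`) can be sketched as follows. The stand-in base class and the `MAE` example are illustrative only; real code would inherit from `deepinv.loss.metric.Metric` and work on torch tensors:

```python
# Minimal stand-in illustrating the documented subclassing contract;
# NOT deepinv's actual base class.

class MetricBase:
    lower_better = True  # subclasses declare the metric's direction

    def metric(self, x_net=None, x=None, *args, **kwargs):
        raise NotImplementedError  # must return one score per batch item

    def invert_metric(self, m):
        # assumed default inversion when a higher-is-better metric
        # is used as a training loss: simple negation
        return [-v for v in m]

class MAE(MetricBase):
    """Hypothetical per-sample mean absolute error; batches are lists of lists."""
    lower_better = True  # smaller MAE means a better reconstruction

    def metric(self, x_net=None, x=None, *args, **kwargs):
        return [sum(abs(a - b) for a, b in zip(est, ref)) / len(ref)
                for est, ref in zip(x_net, x)]

print(MAE().metric([[0.0, 2.0]], [[1.0, 1.0]]))  # [1.0]
```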
- __add__(other)[source]#
Sums two metrics via the + operation.
- Parameters:
other (deepinv.loss.metric.Metric) – other metric
- Returns:
`deepinv.loss.metric.Metric` – summed metric.
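What summation via `+` means can be sketched like this, assuming the combined metric evaluates both operands and adds their per-sample scores element-wise (stand-in classes, not deepinv's code):

```python
# Sketch of summing two metrics via the + operator.

class MetricBase:
    def metric(self, x_net, x):
        raise NotImplementedError

    def __add__(self, other):
        left = self

        class Summed(MetricBase):
            def metric(self, x_net, x):
                a = left.metric(x_net, x)
                b = other.metric(x_net, x)
                return [u + v for u, v in zip(a, b)]  # element-wise sum

        return Summed()

class MAE(MetricBase):
    def metric(self, x_net, x):
        return [sum(abs(a - b) for a, b in zip(e, r)) / len(r)
                for e, r in zip(x_net, x)]

class MSE(MetricBase):
    def metric(self, x_net, x):
        return [sum((a - b) ** 2 for a, b in zip(e, r)) / len(r)
                for e, r in zip(x_net, x)]

total = MAE() + MSE()
print(total.metric([[0.0, 2.0]], [[1.0, 1.0]]))  # [2.0]
```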
- forward(x_net=None, x=None, *args, **kwargs)[source]#
Metric forward pass.
Usually, the data passed is `x_net, x`, i.e. estimate and target, or only `x_net` for a no-reference metric.

The forward pass also optionally calculates the complex magnitude of the images, performs normalization, or inverts the metric for use as a training loss (if higher is better by default).
By default, no reduction is performed in the batch dimension, but mean or sum reduction can be performed too.
All tensors should be of shape `(B, ...)` or `(B, C, ...)` where `B` is the batch size and `C` is the number of channels.

Note

If a full-reference metric is used and a tensor is `None`, a tensor of NaNs will be returned instead.

- Parameters:
  - x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape `(B, ...)` or `(B, C, ...)`.
  - x (torch.Tensor) – Reference image \(x\) (optional) of shape `(B, ...)` or `(B, C, ...)`.
- Returns:
  calculated metric; the tensor size might be `(1,)` or `(B,)`.
- Return type:
  torch.Tensor
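The complex-magnitude step mentioned above, for the 2-channel real/imaginary layout, amounts to the following plain-Python sketch of the assumed semantics (deepinv applies this on tensors before the metric function is called):

```python
# Sketch of complex_abs for a 2-channel (real, imag) layout: each pair
# collapses to its magnitude before the metric is applied.
import math

def complex_abs(x):
    """x: per-sample list of (real, imag) pairs -> list of magnitudes."""
    return [math.hypot(re, im) for re, im in x]

print(complex_abs([(3.0, 4.0), (0.0, 1.0)]))  # [5.0, 1.0]
```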
- invert_metric(m)[source]#
Invert the metric. Used when a higher-is-better metric is to be used as a training loss.
- Parameters:
m (torch.Tensor) – calculated metric
- metric(x_net=None, x=None, *args, **kwargs)[source]#
Calculate metric on data.
Override this function to implement your own metric. Always include `*args` and `**kwargs` arguments. Do not perform reduction.

- Parameters:
  - x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape `(B, ...)` or `(B, C, ...)`.
  - x (torch.Tensor) – Reference image \(x\) (optional) of shape `(B, ...)` or `(B, C, ...)`.
- Returns:
  calculated unreduced metric of shape `(B,)`.
- Return type:
  torch.Tensor