Metric

class deepinv.loss.metric.Metric(metric: Callable | None = None, complex_abs: bool = False, train_loss: bool = False, reduction: str | None = None, norm_inputs: str | None = None)

Bases: Module

Base class for metrics.

See the documentation of forward() below for more details.

To create a new metric, inherit from this class, override the method deepinv.loss.metric.Metric.metric(), set the lower_better attribute, and optionally override the invert_metric() method.
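The subclassing pattern described above can be sketched as follows. To keep the sketch self-contained it does not import deepinv: MetricSketch is a hypothetical, heavily simplified stand-in for the base class (no preprocessing, no reduction), MAE and Closeness are hypothetical subclasses, and negation as the inversion rule is an assumption of this sketch, not necessarily what the library does.

```python
import torch
from torch import nn


class MetricSketch(nn.Module):
    """Hypothetical, simplified stand-in for the Metric base class (not deepinv's code)."""

    lower_better = True  # subclasses set this class attribute

    def __init__(self, train_loss: bool = False):
        super().__init__()
        self.train_loss = train_loss

    def metric(self, x_net, x, *args, **kwargs):
        raise NotImplementedError("override metric() in a subclass")

    def invert_metric(self, m):
        # Assumption: negation turns a higher-is-better score into a loss.
        return -m

    def forward(self, x_net=None, x=None, *args, **kwargs):
        m = self.metric(x_net, x, *args, **kwargs)
        # Only higher-is-better metrics are inverted when used as a training loss.
        if self.train_loss and not self.lower_better:
            m = self.invert_metric(m)
        return m


class MAE(MetricSketch):
    """Hypothetical lower-is-better metric: mean absolute error per sample."""

    lower_better = True

    def metric(self, x_net, x, *args, **kwargs):
        # Reduce over every dimension except the batch dimension -> shape (B,)
        return (x_net - x).abs().flatten(start_dim=1).mean(dim=1)


class Closeness(MetricSketch):
    """Hypothetical higher-is-better score: 1 - MAE per sample."""

    lower_better = False

    def metric(self, x_net, x, *args, **kwargs):
        return 1 - (x_net - x).abs().flatten(start_dim=1).mean(dim=1)


x = torch.zeros(2, 3, 8, 8)  # B, C, H, W
x_net = x + 0.5
print(MAE()(x_net, x))                         # tensor([0.5000, 0.5000])
print(Closeness(train_loss=True)(x_net, x))    # inverted: tensor([-0.5000, -0.5000])
```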

You can also use this base class directly to wrap an existing metric function, e.g. from torchmetrics, to benefit from our preprocessing. The metric function must reduce over all dimensions except the batch dimension (see the example).
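The reduction requirement can be illustrated with plain PyTorch (no deepinv import). Here mse_per_sample is a hypothetical helper: the default torch.nn.functional.mse_loss collapses the batch into a single scalar, while computing element-wise errors and averaging over the non-batch dimensions keeps one score per sample.

```python
import torch
import torch.nn.functional as F

x = torch.zeros(4, 3, 16, 16)  # B, C, H, W
x_net = x + 0.1

# Default reduction averages over everything, batch included: a 0-dim scalar.
batch_reduced = F.mse_loss(x_net, x)
print(batch_reduced.shape)  # torch.Size([])


def mse_per_sample(x_net, x):
    # Element-wise errors, then reduce over all dims except the batch dim.
    return F.mse_loss(x_net, x, reduction="none").flatten(start_dim=1).mean(dim=1)


print(mse_per_sample(x_net, x).shape)  # torch.Size([4])
```

A per-sample function of this shape is the kind of callable the metric argument asks for.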

Parameters:
  • metric (Callable) – metric function. It must reduce over all dims except the batch dim, and must not reduce over the batch dim. Unused if the metric() method is overridden.

  • complex_abs (bool) – compute the complex magnitude before passing data to the metric function. If True, the data must either be of complex dtype or have size 2 in the channel dimension (usually the second dimension, after batch).

  • train_loss (bool) – if True and higher is better for this metric, invert the metric so it can be minimised as a training loss. If lower is better, this has no effect.

  • reduction (str) – method used to reduce the individual batch scores: mean takes the mean, sum takes the sum, and none or None applies no reduction (default).

  • norm_inputs (str) – normalize images before passing them to the metric: l2 normalizes by the L2 spatial norm, min_max normalizes by the min and max of each input.
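The complex_abs and norm_inputs preprocessing can be sketched in plain PyTorch. This is an illustration under assumptions, not deepinv's implementation: the helper names complex_magnitude and normalize are hypothetical, channel 0 is taken as the real part and channel 1 as the imaginary part, and both normalisations are computed per sample over all non-batch dimensions.

```python
import torch


def complex_magnitude(x: torch.Tensor) -> torch.Tensor:
    """Magnitude of a complex tensor, or of a real (B, 2, ...) tensor holding real/imag parts."""
    if x.is_complex():
        return x.abs()
    if x.shape[1] != 2:
        raise ValueError("expected a complex dtype or size 2 in the channel dimension")
    # Assumption: channel 0 is the real part, channel 1 the imaginary part.
    return torch.sqrt(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)


def normalize(x: torch.Tensor, mode: str) -> torch.Tensor:
    """Per-sample normalisation over all non-batch dimensions."""
    flat = x.flatten(start_dim=1)            # (B, N)
    shape = (-1,) + (1,) * (x.dim() - 1)     # reshape per-sample stats to broadcast over x
    if mode == "l2":
        return x / torch.linalg.vector_norm(flat, dim=1).view(shape)
    if mode == "min_max":
        lo = flat.min(dim=1).values.view(shape)
        hi = flat.max(dim=1).values.view(shape)
        return (x - lo) / (hi - lo)          # a constant sample would divide by zero here
    raise ValueError(f"unknown normalisation mode: {mode!r}")


x = torch.zeros(2, 2, 4, 4)      # B, C=2 (real, imag), H, W
x[:, 0] = 3.0
x[:, 1] = 4.0
print(complex_magnitude(x).shape)   # torch.Size([2, 1, 4, 4]), every value is 5

z = torch.complex(x[:, 0], x[:, 1])  # same data stored as a complex dtype, shape (2, 4, 4)
print(complex_magnitude(z).shape)   # torch.Size([2, 4, 4])

y = torch.arange(8.0).view(2, 4)    # two samples with different value ranges
print(normalize(y, "min_max"))      # each row becomes [0, 1/3, 2/3, 1]
```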


Examples

Use Metric to wrap functional metrics, such as those from torchmetrics:

>>> from functools import partial
>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> from deepinv.loss.metric import Metric
>>> from torch import ones
>>> m = Metric(metric=partial(structural_similarity_index_measure, reduction='none'))
>>> x = x_net = ones(2, 3, 64, 64) # B,C,H,W
>>> m(x_net - 0.1, x)
tensor([0., 0.])

forward(x_net: Tensor | None = None, x: Tensor | None = None, *args, **kwargs) → Tensor

Metric forward pass.

Usually, the data passed is x_net and x, i.e. the estimate and the target, or only x_net for a no-reference metric.

The forward pass can also compute the complex magnitude of the images, normalise them, or invert the metric so it can be used as a training loss (if higher is better by default).

By default, no reduction is performed over the batch dimension, but mean or sum reduction can be applied instead.
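The three reduction options can be sketched on a vector of per-sample scores. The helper reduce_scores is hypothetical; in this sketch mean and sum return a 0-dimensional tensor, and the exact shape returned by the library may differ.

```python
import torch


def reduce_scores(m: torch.Tensor, reduction=None) -> torch.Tensor:
    """Reduce per-sample scores of shape (B,) according to `reduction`."""
    if reduction is None or reduction == "none":
        return m            # keep one score per batch element, shape (B,)
    if reduction == "mean":
        return m.mean()     # average over the batch
    if reduction == "sum":
        return m.sum()      # total over the batch
    raise ValueError(f"unknown reduction: {reduction!r}")


scores = torch.tensor([0.8, 0.6, 0.7])  # e.g. one metric value per image
print(reduce_scores(scores).shape)          # torch.Size([3])
print(reduce_scores(scores, "mean").shape)  # torch.Size([])
```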

All tensors should be of shape (B, ...) or (B, C, ...), where B is the batch size and C the number of channels.

Parameters:
  • x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape (B, ...) or (B, C, ...).

  • x (torch.Tensor) – Reference image \(x\) (optional) of shape (B, ...) or (B, C, ...).

Return torch.Tensor:

the calculated metric; its size may be (1,) or (B,), depending on the reduction.

invert_metric(m: Tensor)

Invert the metric. Used when a higher-is-better metric is to be used as a training loss.

Parameters:

m (Tensor) – calculated metric
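Inversion is needed because an optimiser minimises its loss, whereas a higher-is-better score such as PSNR should be maximised. The toy sketch below assumes negation as the inversion rule (deepinv's invert_metric may use a different transformation) and uses a hypothetical psnr helper with a maximum pixel value of 1; one gradient-descent step on the inverted metric increases the PSNR.

```python
import torch


def psnr(x_net, x, max_pixel=1.0):
    # Per-sample PSNR in dB, reducing over all non-batch dims -> shape (B,)
    mse = (x_net - x).pow(2).flatten(start_dim=1).mean(dim=1)
    return 10 * torch.log10(max_pixel**2 / mse)


def invert_metric(m):
    # Assumption: negation, so minimising the loss maximises the metric.
    return -m


x = torch.full((1, 1, 8, 8), 0.5)           # target
x_net = (x + 0.2).requires_grad_()          # estimate with a constant error

loss = invert_metric(psnr(x_net, x)).mean()
loss.backward()

with torch.no_grad():
    x_net_new = x_net - 0.1 * x_net.grad    # one gradient-descent step on the loss
    psnr_before = psnr(x_net, x).item()
    psnr_after = psnr(x_net_new, x).item()

print(psnr_before, psnr_after)  # the second value is higher
```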

metric(x_net: Tensor | None = None, x: Tensor | None = None, *args, **kwargs) → Tensor

Calculate the metric on the data.

Override this method to implement your own metric. Always include the *args and **kwargs arguments.
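Keeping *args and **kwargs matters when extra inputs are passed through to metric(), as the forward() signature suggests. The plain-Python sketch below shows the failure mode; call_metric, strict_metric, tolerant_metric and the extra argument y are hypothetical, and the exact extra arguments the library forwards are not specified here.

```python
import torch


def call_metric(metric_fn, x_net, x, *extra, **kw):
    # A caller that forwards any extra positional/keyword arguments it received.
    return metric_fn(x_net, x, *extra, **kw)


def strict_metric(x_net, x):
    # No *args/**kwargs: extra arguments cause a TypeError.
    return (x_net - x).abs().flatten(start_dim=1).mean(dim=1)


def tolerant_metric(x_net, x, *args, **kwargs):
    # Extra arguments are accepted and simply ignored.
    return (x_net - x).abs().flatten(start_dim=1).mean(dim=1)


x = torch.zeros(2, 4)
x_net = torch.ones(2, 4)
y = torch.zeros(2, 4)  # hypothetical extra input, e.g. measurements

print(call_metric(tolerant_metric, x_net, x, y, physics=None))  # tensor([1., 1.])

try:
    call_metric(strict_metric, x_net, x, y, physics=None)
    strict_failed = False
except TypeError:
    strict_failed = True
print(strict_failed)  # True
```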

Parameters:
  • x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape (B, ...) or (B, C, ...).

  • x (torch.Tensor) – Reference image \(x\) (optional) of shape (B, ...) or (B, C, ...).

Return torch.Tensor:

the calculated metric; its size may be (1,) or (B,), depending on the reduction.