SSIM#

class deepinv.loss.metric.SSIM(multiscale=False, max_pixel=1.0, torchmetric_kwargs: dict = {}, **kwargs)[source]#

Bases: Metric

Structural Similarity Index (SSIM) metric using torchmetrics.

Calculates the SSIM \(\text{SSIM}(\hat{x},x)\) where \(\hat{x}=\inverse{y}\). See https://en.wikipedia.org/wiki/Structural_similarity for more information.

To set the maximum pixel value on the fly (as is done in the fastMRI evaluation code), set max_pixel=None.

Note

By default, no reduction is performed in the batch dimension.

Example:

>>> import torch
>>> from deepinv.loss.metric import SSIM
>>> m = SSIM()
>>> x_net = x = torch.ones(3, 2, 32, 32) # B,C,H,W
>>> m(x_net, x)
tensor([1., 1., 1.])
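
When max_pixel=None, the data range is taken from the reference x at evaluation time. A minimal sketch of this usage, continuing the example above (the random inputs are illustrative only):

>>> m = SSIM(max_pixel=None)  # max pixel value inferred from x on the fly
>>> x = torch.rand(3, 2, 32, 32) * 255  # e.g. unnormalized images
>>> x_net = x.clone()
>>> s = m(x_net, x)  # per-sample SSIM of shape (3,)
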
Parameters:
  • multiscale (bool) – if True, computes the multiscale SSIM. Default: False.

  • max_pixel (float) – maximum pixel value. If None, uses max pixel value of x.

  • torchmetric_kwargs (dict) – kwargs for torchmetrics SSIM as dict. See https://lightning.ai/docs/torchmetrics/stable/image/structural_similarity.html

  • complex_abs (bool) – perform complex magnitude before passing data to metric function. If True, the data must either be of complex dtype or have size 2 in the channel dimension (usually the second dimension after batch).

  • train_loss (bool) – use metric as a training loss, by returning one minus the metric.

  • reduction (str) – method to reduce the metric over individual batch scores: mean takes the mean, sum takes the sum, none or None applies no reduction (default).

  • norm_inputs (str) – normalize images before passing to the metric. l2 normalizes by the L2 spatial norm, min_max normalizes by the min and max of each input.
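
As with the other Metric options above, the train_loss and reduction arguments can turn SSIM into a scalar training objective. A minimal sketch assuming default settings otherwise; the random tensors are placeholders:

>>> loss_fn = SSIM(train_loss=True, reduction='mean')  # returns 1 - SSIM, averaged over the batch
>>> x_net, x = torch.rand(4, 1, 32, 32), torch.rand(4, 1, 32, 32)
>>> loss = loss_fn(x_net, x)  # scalar value; lower is better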

invert_metric(m)[source]#

Invert the metric. Used when a higher-is-better metric (such as SSIM) is to be used as a training loss.

Parameters:

m (Tensor) – calculated metric
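
A minimal sketch of the inversion, assuming it returns one minus the metric (consistent with the train_loss behaviour described above); the scores are placeholders:

>>> s = torch.tensor([0.9, 0.8])  # hypothetical per-sample SSIM scores
>>> loss = SSIM().invert_metric(s)  # expected to equal 1 - s, so that lower is better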

metric(x_net, x, *args, **kwargs)[source]#

Calculate metric on data.

Override this function to implement your own metric. Always include *args and **kwargs in the signature.

Parameters:
  • x_net (torch.Tensor) – Reconstructed image \(\hat{x}=\inverse{y}\) of shape (B, ...) or (B, C, ...).

  • x (torch.Tensor) – Reference image \(x\) (optional) of shape (B, ...) or (B, C, ...).

Returns:

calculated metric (torch.Tensor); the tensor size may be (1,) or (B,).
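
For illustration, a minimal sketch of overriding metric() in a custom subclass; MyMAE is a hypothetical per-sample mean absolute error, not part of deepinv:

>>> import torch
>>> from deepinv.loss.metric import Metric
>>> class MyMAE(Metric):
...     def metric(self, x_net, x, *args, **kwargs):
...         # per-sample mean absolute error over all non-batch dimensions
...         return (x_net - x).abs().flatten(1).mean(dim=1)
...
>>> m = MyMAE()
>>> scores = m(torch.ones(3, 2, 8, 8), torch.zeros(3, 2, 8, 8))  # shape (3,); one score per sample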