SSIM
- class deepinv.loss.metric.SSIM(multiscale=False, max_pixel=1.0, torchmetric_kwargs={}, **kwargs)
Bases: Metric
Structural Similarity Index (SSIM) metric using torchmetrics.
Calculates \(\text{SSIM}(\hat{x},x)\) where \(\hat{x}=\inverse{y}\). See https://en.wikipedia.org/wiki/Structural_similarity for more information.
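For reference, the standard SSIM definition compares the means, variances and covariance of the two images. The constants \(c_1, c_2\) stabilize the ratio and depend on the dynamic range \(L\). We assume that max_pixel plays the role of \(L\), with \(k_1=0.01\) and \(k_2=0.03\) (the usual defaults, which torchmetrics also uses). In torchmetrics these statistics are computed over local Gaussian windows, and the resulting SSIM map is averaged.

\[
\text{SSIM}(\hat{x},x)=\frac{(2\mu_{\hat{x}}\mu_{x}+c_1)\,(2\sigma_{\hat{x}x}+c_2)}{(\mu_{\hat{x}}^2+\mu_{x}^2+c_1)\,(\sigma_{\hat{x}}^2+\sigma_{x}^2+c_2)},\qquad c_1=(k_1L)^2,\quad c_2=(k_2L)^2.
\]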
To set the maximum pixel value on the fly (as in the fastMRI evaluation code), set max_pixel=None.
Note
By default, no reduction is performed in the batch dimension.
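The role of max_pixel and the default per-sample output can be illustrated with a minimal numpy sketch. It uses a simplified single-window SSIM, with one global mean, variance and covariance per sample, instead of the sliding Gaussian windows that torchmetrics uses, so its values will not match the library. The helper name global_ssim is hypothetical. The max_pixel=None branch assumes the dynamic range is taken as the maximum of the reference x.

```python
import numpy as np


def global_ssim(x_hat, x, max_pixel=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM per batch element (simplified, hypothetical helper).

    Unlike torchmetrics, which averages SSIM over sliding Gaussian windows,
    this uses one global mean/variance/covariance per sample.
    x_hat, x: real arrays of shape (B, C, H, W).
    max_pixel: dynamic range L; if None, use the maximum value of x
    (assumed behaviour of the on-the-fly mode).
    Returns an array of shape (B,): no reduction over the batch.
    """
    x_hat = np.asarray(x_hat, dtype=np.float64)
    x = np.asarray(x, dtype=np.float64)
    L = x.max() if max_pixel is None else max_pixel
    c1, c2 = (k1 * L) ** 2, (k2 * L) ** 2

    axes = (1, 2, 3)  # statistics over C, H, W; the batch axis is kept
    mu_a = x_hat.mean(axis=axes, keepdims=True)
    mu_b = x.mean(axis=axes, keepdims=True)
    var_a = ((x_hat - mu_a) ** 2).mean(axis=axes)
    var_b = ((x - mu_b) ** 2).mean(axis=axes)
    cov = ((x_hat - mu_a) * (x - mu_b)).mean(axis=axes)
    mu_a, mu_b = mu_a.reshape(-1), mu_b.reshape(-1)

    num = (2 * mu_a * mu_b + c1) * (2 * cov + c2)
    den = (mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2)
    return num / den


x = np.ones((3, 2, 32, 32))  # B, C, H, W
print(global_ssim(x, x))  # [1. 1. 1.]
```

Scaling both images and L by the same factor leaves SSIM unchanged. The choice of max_pixel therefore matters mainly when the data are not in [0, 1]: images in [0, 255] evaluated with max_pixel=1.0 get constants that are far too small.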
- Example:
>>> import torch
>>> from deepinv.loss.metric import SSIM
>>> m = SSIM()
>>> x_net = x = torch.ones(3, 2, 32, 32)  # B, C, H, W
>>> m(x_net, x)
tensor([1., 1., 1.])
- Parameters:
multiscale (bool) – if True, computes the multiscale SSIM. Default: False.
max_pixel (float) – maximum pixel value. If None, uses the maximum pixel value of x. Default: 1.0.
torchmetric_kwargs (dict) – keyword arguments for the torchmetrics SSIM, given as a dict. See https://lightning.ai/docs/torchmetrics/stable/image/structural_similarity.html
complex_abs (bool) – compute the complex magnitude before passing the data to the metric function. If True, the data must either be of complex dtype or have size 2 in the channel dimension (usually the second dimension, after batch).
train_loss (bool) – use the metric as a training loss by returning one minus the metric.
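The complex_abs preprocessing can be sketched in numpy as follows, assuming a (B, C, H, W) layout. The helper name complex_magnitude is hypothetical. For real data with two channels, it assumes channel 0 holds the real part and channel 1 the imaginary part, and it keeps a singleton channel axis. The library may order channels or handle shapes differently.

```python
import numpy as np


def complex_magnitude(x):
    """Magnitude of complex data laid out as (B, C, H, W) (hypothetical helper).

    - complex dtype: elementwise modulus, shape unchanged;
    - real dtype with C == 2: channel 0 is assumed to hold the real part and
      channel 1 the imaginary part; returns shape (B, 1, H, W).
    """
    x = np.asarray(x)
    if np.iscomplexobj(x):
        return np.abs(x)
    if x.ndim < 2 or x.shape[1] != 2:
        raise ValueError("expected a complex dtype or size 2 in the channel dimension")
    return np.sqrt(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)


z = np.zeros((1, 2, 4, 4))
z[:, 0] = 3.0  # real part
z[:, 1] = 4.0  # imaginary part
print(complex_magnitude(z).shape)  # (1, 1, 4, 4)
```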
reduction (str) – method used to reduce the metric over the individual batch scores: mean takes the mean, sum takes the sum, and none or None applies no reduction (default).
norm_inputs (str) – normalize images before passing them to the metric: l2 normalizes by the L2 spatial norm, and min_max normalizes by the minimum and maximum of each input.
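The norm_inputs and reduction options can be sketched in numpy as follows. This is only an illustration: normalize_inputs and reduce_scores are hypothetical names. The sketch assumes that normalization is done per sample over all non-batch axes (C, H, W). The library's "spatial" L2 norm may instead be taken over H and W only.

```python
import numpy as np


def normalize_inputs(x, mode=None):
    """Per-sample normalization of a (B, C, H, W) array (hypothetical helper).

    Assumption: statistics are taken over all non-batch axes (C, H, W).
    """
    x = np.asarray(x, dtype=np.float64)
    axes = (1, 2, 3)
    if mode is None:
        return x
    if mode == "l2":
        # Divide each sample by its L2 norm.
        return x / np.sqrt((x ** 2).sum(axis=axes, keepdims=True))
    if mode == "min_max":
        # Map each sample affinely onto [0, 1] using its own min and max.
        lo = x.min(axis=axes, keepdims=True)
        hi = x.max(axis=axes, keepdims=True)
        return (x - lo) / (hi - lo)
    raise ValueError(f"unknown norm_inputs mode: {mode!r}")


def reduce_scores(scores, reduction=None):
    """Reduce a (B,) vector of per-sample scores (hypothetical helper)."""
    scores = np.asarray(scores, dtype=np.float64)
    if reduction == "mean":
        return scores.mean()
    if reduction == "sum":
        return scores.sum()
    if reduction is None or reduction == "none":
        return scores  # default: one score per batch element
    raise ValueError(f"unknown reduction: {reduction!r}")


scores = np.array([0.5, 0.75, 1.0])
print(reduce_scores(scores, "mean"))  # 0.75
```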