Source code for deepinv.loss.mc

from typing import Union

import torch
from deepinv.loss.loss import Loss
from deepinv.loss.metric.metric import Metric
from deepinv.transform.base import Transform


[docs] class MCLoss(Loss): r""" Measurement consistency loss This loss enforces that the reconstructions are measurement-consistent, i.e., :math:`y=\forw{\inverse{y}}`. The measurement consistency loss is defined as .. math:: \|y-\forw{\inverse{y}}\|^2 where :math:`\inverse{y}` is the reconstructed signal and :math:`A` is a forward operator. By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`) can be used as well. :param Metric, torch.nn.Module metric: metric used for computing data consistency, which is set as the mean squared error by default. """ def __init__(self, metric: Union[Metric, torch.nn.Module] = torch.nn.MSELoss()): super(MCLoss, self).__init__() self.name = "mc" self.metric = metric
[docs] def forward(self, y, x_net, physics, **kwargs): r""" Computes the measurement splitting loss :param torch.Tensor y: measurements. :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`. :param deepinv.physics.Physics physics: forward operator associated with the measurements. :return: (torch.Tensor) loss. """ return self.metric(physics.A(x_net), y)