TVLoss
- class deepinv.loss.TVLoss(weight=1.0)
Bases: Loss
Total variation loss (squared \(\ell_2\) norm).
It computes the loss \(\|D\hat{x}\|_2^2\), where \(D\) is a normalized linear operator that computes the vertical and horizontal first-order differences of the reconstructed image \(\hat{x}\).
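A minimal PyTorch sketch of the computation described above. It assumes an input of shape (B, C, H, W) and interprets "normalized" as averaging the squared differences over all difference entries; the library's exact normalization constant may differ. The function name `tv_loss_sketch` is hypothetical and not part of deepinv; `weight` mirrors the class parameter.

```python
import torch


def tv_loss_sketch(x_hat: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch of a squared-l2 TV loss, one value per batch element.

    Assumption: x_hat has shape (B, C, H, W), and the normalization of D is
    taken to be a mean over difference entries (not deepinv's exact code).
    """
    # Vertical first-order differences x[i+1, j] - x[i, j], shape (B, C, H-1, W)
    dv = x_hat[:, :, 1:, :] - x_hat[:, :, :-1, :]
    # Horizontal first-order differences x[i, j+1] - x[i, j], shape (B, C, H, W-1)
    dh = x_hat[:, :, :, 1:] - x_hat[:, :, :, :-1]
    # Mean squared difference per sample in each direction, then summed
    per_sample = dv.pow(2).flatten(1).mean(dim=1) + dh.pow(2).flatten(1).mean(dim=1)
    # Scale by the scalar weight; result has shape (B,)
    return weight * per_sample


# Usage: a constant image has zero TV; a horizontal ramp has only horizontal differences.
flat = torch.ones(2, 1, 8, 8)
ramp = torch.arange(8.0).expand(1, 1, 8, 8)  # pixel value equals its column index
print(tv_loss_sketch(flat))              # zero for both samples
print(tv_loss_sketch(ramp, weight=0.5))  # 0.5: each horizontal difference is 1, vertical ones are 0
```

Averaging rather than summing keeps the loss scale independent of the image size; replacing `.mean(dim=1)` with `.sum(dim=1)` gives the unnormalized \(\|D\hat{x}\|_2^2\).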
- Parameters:
weight (float) – scalar weight for the TV loss.
- forward(x_net, **kwargs)
Computes the TV loss.
- Parameters:
x_net (torch.Tensor) – reconstructed image.
- Returns:
torch.Tensor loss of size (batch_size,)