TVLoss

class deepinv.loss.TVLoss(weight=1.0)

Bases: Loss

Total variation loss (squared \(\ell_2\) norm).

It computes the loss \(\|D\hat{x}\|_2^2\), where \(D\) is a normalized linear operator that computes the vertical and horizontal first-order differences of the reconstructed image \(\hat{x}\).
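To make the formula concrete, here is a minimal PyTorch sketch of \(\|D\hat{x}\|_2^2\) for images of shape (B, C, H, W). The helper names `finite_differences` and `tv_l2` are hypothetical, and the normalization is an assumption: the text above only says \(D\) is "normalized", so this sketch averages each squared difference map over its entries. The library's exact scaling may differ.

```python
import torch


def finite_differences(x: torch.Tensor):
    """Vertical and horizontal first-order differences of x, shape (B, C, H, W)."""
    dv = x[:, :, 1:, :] - x[:, :, :-1, :]  # vertical: shape (B, C, H-1, W)
    dh = x[:, :, :, 1:] - x[:, :, :, :-1]  # horizontal: shape (B, C, H, W-1)
    return dv, dh


def tv_l2(x: torch.Tensor) -> torch.Tensor:
    """Squared l2 norm of the differences, one value per batch element.

    Assumed normalization: each squared difference map is averaged over
    its own entries (not necessarily what deepinv does internally).
    """
    dv, dh = finite_differences(x)
    return dv.pow(2).mean(dim=(1, 2, 3)) + dh.pow(2).mean(dim=(1, 2, 3))


# Tiny 2x2 single-channel image:
# vertical diffs (2, 3) -> mean of squares 6.5
# horizontal diffs (1, 2) -> mean of squares 2.5
x = torch.tensor([[[[0.0, 1.0],
                    [2.0, 4.0]]]])  # shape (1, 1, 2, 2)
print(tv_l2(x))  # tensor([9.])
```

A constant image has zero differences and therefore zero loss, which is why this penalty favors piecewise-smooth reconstructions.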

Parameters:

weight (float) – scalar weight for the TV loss.

forward(x_net, **kwargs)

Computes the TV loss.

Parameters:

x_net (torch.Tensor) – reconstructed image.

Returns:

torch.Tensor – loss of size (batch_size,).
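The documented interface (a scalar `weight`, and `forward(x_net, **kwargs)` returning one value per batch element) can be mirrored by a small `torch.nn.Module`. `SketchTVLoss` below is a hypothetical stand-in, not deepinv's implementation; it assumes the weight simply multiplies the per-sample loss and uses the same assumed normalization as a plain average of squared differences.

```python
import torch
import torch.nn as nn


class SketchTVLoss(nn.Module):
    """Hypothetical stand-in mirroring the documented TVLoss interface."""

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight  # scalar multiplier for the loss (assumed role)

    def forward(self, x_net: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments are accepted and ignored, as the
        # regularizer depends only on the reconstruction x_net.
        dv = x_net[:, :, 1:, :] - x_net[:, :, :-1, :]
        dh = x_net[:, :, :, 1:] - x_net[:, :, :, :-1]
        per_sample = dv.pow(2).mean(dim=(1, 2, 3)) + dh.pow(2).mean(dim=(1, 2, 3))
        return self.weight * per_sample  # shape (batch_size,)


loss_fn = SketchTVLoss(weight=0.1)
x_net = torch.rand(4, 3, 32, 32)
loss = loss_fn(x_net)
print(loss.shape)  # torch.Size([4])
total = loss.mean()  # reduce to a scalar before calling backward()
```

Returning a per-sample vector rather than a scalar leaves the reduction (mean, sum, or per-sample weighting) to the caller, which is convenient when the TV term is added to other losses.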