ReducedResolutionLoss
- class deepinv.loss.ReducedResolutionLoss(metric=None)
Bases: SupLoss
Reduced resolution loss for blur and downsampling problems.
The reduced resolution loss is defined as
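\[\frac{1}{n}\|y-\inverse{\forw{y}}\|^2\]
where \(\forw{y}\) is the reduced resolution measurement obtained by further degrading \(y\) with the forward operator, \(\inverse{\cdot}\) is the reconstruction model, \(n\) is the number of entries of \(y\), and the measurement \(y\) is used as the supervisory signal.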
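A minimal sketch of the formula in plain PyTorch is shown below. It assumes a toy forward operator (2x2 average pooling) and a toy reconstruction (bicubic upsampling) in place of a real deepinv physics and a trained network. The helpers degrade, reconstruct and reduced_resolution_loss are hypothetical names for illustration and are not part of deepinv.

```python
import torch
import torch.nn.functional as F


def degrade(t: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Hypothetical forward operator A: average pooling by `factor`.

    It has no fixed image size, so it applies equally to x and to y.
    """
    return F.avg_pool2d(t, kernel_size=factor)


def reconstruct(t: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Hypothetical reconstruction R: bicubic upsampling (stands in for a network)."""
    return F.interpolate(t, scale_factor=factor, mode="bicubic", align_corners=False)


def reduced_resolution_loss(y: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """(1/n) * ||y - R(A(y))||^2, with the measurement y as the only target."""
    y_reduced = degrade(y, factor)          # A(y): further-degraded measurement
    y_hat = reconstruct(y_reduced, factor)  # R(A(y)): back to the resolution of y
    return F.mse_loss(y_hat, y)             # mean over all n entries of y


torch.manual_seed(0)
x = torch.rand(1, 3, 64, 64)  # ground truth (never used by the loss)
y = degrade(x)                # measurement, shape (1, 3, 32, 32)
loss = reduced_resolution_loss(y)
print(y.shape, loss.item())
```

Note that R(A(y)) has the shape of y, (1, 3, 32, 32), which is smaller than the ground truth x; this is the shape mismatch mentioned in the Hint.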
Warning
This loss can only be used with physics that can meaningfully degrade the measurements \(y\) further, such as blur or downsampling. The physics must be defined without an img_size so that it can be applied to the measurements \(y\).
Hint
During training, consider using the disable_train_metrics option in deepinv.Trainer to prevent a shape mismatch during metric computation, since the reduced resolution output will be smaller than the ground truth.
This loss was used in Shocher et al. [1] for downsampling tasks, and is known as Wald's protocol [2] for pan-sharpening tasks.
- Parameters:
metric (Metric, torch.nn.Module) – metric used to compute data consistency; defaults to the mean squared error.
- References:
- forward(x_net, y, *args, **kwargs)
Computes the reduced resolution loss.
- Parameters:
x_net (torch.Tensor) – reconstructions.
y (torch.Tensor) – measurements.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction function.
- Returns:
(torch.Tensor) loss.
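The forward computation described above can be sketched as a small self-supervised training loop. This is a hypothetical re-implementation in plain PyTorch, not deepinv's source code: ReducedResolutionLossSketch, the toy operator physics (2x2 average pooling standing in for a deepinv physics) and ToyUpsampler are illustrative assumptions, and the model(y, physics) calling convention is assumed to mirror deepinv's reconstruction models. The point it shows is that only the measurements y are used as targets.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def physics(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical size-agnostic forward operator: 2x2 average pooling."""
    return F.avg_pool2d(t, kernel_size=2)


class ToyUpsampler(nn.Module):
    """Hypothetical reconstruction model: bilinear x2 upsampling plus a learned residual."""

    def __init__(self, channels: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, y: torch.Tensor, physics) -> torch.Tensor:
        up = F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=False)
        return up + self.conv(up)


class ReducedResolutionLossSketch(nn.Module):
    """Hypothetical stand-in mirroring forward(x_net, y, physics, model)."""

    def __init__(self, metric: Optional[nn.Module] = None):
        super().__init__()
        # Fall back to mean squared error, matching the documented default.
        self.metric = metric if metric is not None else nn.MSELoss()

    def forward(self, x_net, y, physics, model, **kwargs):
        # x_net is accepted for signature compatibility but not used here.
        y_reduced = physics(y)                 # further degrade the measurement
        x_reduced = model(y_reduced, physics)  # reconstruct at the resolution of y
        return self.metric(x_reduced, y)       # y is the only supervisory signal


torch.manual_seed(0)
x = torch.rand(8, 1, 32, 32)  # ground truth, used only to simulate measurements
y = physics(x)                # measurements, shape (8, 1, 16, 16)

model = ToyUpsampler()
loss_fn = ReducedResolutionLossSketch()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for _ in range(50):
    optimizer.zero_grad()
    x_net = model(y, physics)                 # full-size output, shape (8, 1, 32, 32)
    loss = loss_fn(x_net, y, physics, model)  # never touches x
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Here x_net is 32x32 while the tensors compared by the loss are 16x16, which illustrates why the reduced resolution output cannot be compared directly with the ground truth.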