ReducedResolutionLoss
- class deepinv.loss.ReducedResolutionLoss(metric=None, physics=None)
Bases: SupLoss
Reduced resolution loss for blur and downsampling problems.
The reduced resolution loss is defined as
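\[\frac{1}{n}\|y-\inverse{\forw{y}}\|^2\]
where \(\forw{y}\) is the reduced resolution measurement obtained by further degrading the measurement \(y\), \(\inverse{\cdot}\) is the reconstruction model, \(n\) is the number of entries of \(y\), and \(y\) itself is used as the supervisory signal.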
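The formula can be checked numerically with plain PyTorch. This is a minimal sketch under stated assumptions, not deepinv's implementation: the further-degradation operator is a hypothetical 2× average pooling (degrade), and the reconstructor is a hypothetical bilinear upsampling (reconstruct) standing in for a trained network.

```python
import torch
import torch.nn.functional as F


def degrade(y: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Hypothetical further-degradation operator: average pooling by `factor`."""
    return F.avg_pool2d(y, kernel_size=factor)


def reconstruct(y_reduced: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Hypothetical reconstructor: bilinear upsampling (stand-in for a network)."""
    return F.interpolate(y_reduced, scale_factor=factor, mode="bilinear", align_corners=False)


def reduced_resolution_loss(y: torch.Tensor, factor: int = 2) -> torch.Tensor:
    y_reduced = degrade(y, factor)          # further degrade the measurement
    y_hat = reconstruct(y_reduced, factor)  # reconstruct back to the size of y
    # Mean over all n entries of y: (1/n) * ||y - y_hat||^2
    return F.mse_loss(y_hat, y)


torch.manual_seed(0)
y = torch.rand(1, 3, 32, 32)  # the measurement is also the target
loss = reduced_resolution_loss(y)
print(f"reduced resolution loss: {loss.item():.4f}")
```

Because F.mse_loss averages over every entry, it matches the \(\frac{1}{n}\) normalization in the formula.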
Note
Optionally initialize with physics to fix the reduced resolution operator. If it is not passed, the loss takes the physics from the forward pass during training. This loss should only be used with physics that can meaningfully further degrade the measurements \(y\), such as blur or downsampling. The physics must be defined without an img_size so that it can be applied to the measurements \(y\).
At test time, the model does not perform the reduced resolution measurement.
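The train/test behaviour described in this note can be sketched with plain PyTorch. This is an illustration only, not deepinv's code: AvgPoolPhysics, UpsampleModel and ReducedResolutionWrapper are hypothetical names, and the physics is assumed to be a size-agnostic average pooling so that it can be applied to \(y\) directly.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class AvgPoolPhysics(nn.Module):
    """Hypothetical degradation with no fixed image size: average pooling."""

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.avg_pool2d(x, kernel_size=self.factor)


class UpsampleModel(nn.Module):
    """Hypothetical reconstructor: bilinear upsampling (ignores `physics`)."""

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, y: torch.Tensor, physics: nn.Module) -> torch.Tensor:
        return F.interpolate(y, scale_factor=self.factor, mode="bilinear", align_corners=False)


class ReducedResolutionWrapper(nn.Module):
    """Hypothetical wrapper: degrades y further in training mode only."""

    def __init__(self, model: nn.Module, physics: Optional[nn.Module] = None):
        super().__init__()
        self.model = model
        self.physics = physics  # optional fixed reduced-resolution operator

    def forward(self, y: torch.Tensor, physics: nn.Module) -> torch.Tensor:
        if self.training:
            # Use the fixed operator if given, otherwise the forward-pass physics
            op = self.physics if self.physics is not None else physics
            return self.model(op(y), op)
        # Test time: reconstruct directly from y, without further degradation
        return self.model(y, physics)


physics = AvgPoolPhysics(factor=2)
wrapper = ReducedResolutionWrapper(UpsampleModel(factor=2))
y = torch.rand(1, 3, 16, 16)

wrapper.train()
y_hat = wrapper(y, physics)  # same size as y, so it can be compared with y
wrapper.eval()
x_hat = wrapper(y, physics)  # full-resolution estimate from y
print(y_hat.shape, x_hat.shape)
```

Passing a fixed operator to the wrapper decouples the training degradation from the physics used at test time, which mirrors the optional physics argument of the loss.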
Hint
During training, consider using the disable_train_metrics option in deepinv.Trainer to prevent a shape mismatch during metric computation when the reduced resolution output is smaller than the ground truth.
This loss was used in Shocher et al. [1] for downsampling tasks, and is known as Wald's protocol [2] for pan-sharpening tasks.
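The shape mismatch mentioned in the hint can be reproduced with plain PyTorch. This sketch assumes a hypothetical 2× average-pooling forward operator and bilinear upsampling as a stand-in reconstructor; during training the output has the size of \(y\), not of the ground truth \(x\).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 32, 32)  # ground truth
y = F.avg_pool2d(x, 2)        # measurement: 16x16 (hypothetical 2x downsampling)

# Training-time output: reconstruct from the further-degraded y (8x8 -> 16x16)
x_net = F.interpolate(F.avg_pool2d(y, 2), scale_factor=2, mode="bilinear", align_corners=False)

# The loss compares x_net with y, whose shapes agree...
loss = F.mse_loss(x_net, y)
# ...but a metric against the ground truth would compare 16x16 with 32x32
shapes_match = x_net.shape == x.shape
print(loss.item(), shapes_match)
```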
- Parameters:
metric (Metric, torch.nn.Module) – metric used for computing data consistency; defaults to the mean squared error.
physics (Physics) – optional physics used to perform the reduced resolution measurement. If not specified, the physics from the forward pass is used.
- References:
- forward(x_net, y, *args, **kwargs)
Computes the reduced resolution loss.
- Parameters:
x_net (torch.Tensor) – Reconstructions.
y (torch.Tensor) – Measurements.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction function.
- Returns:
(torch.Tensor) loss.
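A loss with this interface might look like the sketch below. It is a hypothetical stand-in (SketchReducedResolutionLoss is not a deepinv class) and assumes that x_net is already the reconstruction from the further-degraded measurement, so forward only needs to compare x_net with y using the chosen metric, which defaults to the mean squared error.

```python
from typing import Optional

import torch
import torch.nn as nn


class SketchReducedResolutionLoss(nn.Module):
    """Hypothetical loss: compares the reduced resolution output with y."""

    def __init__(self, metric: Optional[nn.Module] = None):
        super().__init__()
        # Default metric: mean squared error, matching the (1/n)||.||^2 formula
        self.metric = metric if metric is not None else nn.MSELoss()

    def forward(self, x_net: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # x_net is assumed to be R(A(y)), so y acts as the target
        return self.metric(x_net, y)


y = torch.ones(2, 1, 4, 4)
x_net = torch.zeros(2, 1, 4, 4)
mse = SketchReducedResolutionLoss()(x_net, y)
l1 = SketchReducedResolutionLoss(metric=nn.L1Loss())(x_net, y)
print(mse.item(), l1.item())
```

Swapping the metric (here nn.L1Loss) changes only how the discrepancy with \(y\) is measured; the reduced resolution setup stays the same.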