AugmentConsistencyLoss
- class deepinv.loss.AugmentConsistencyLoss(T_i=None, T_e=None, metric=torch.nn.MSELoss(), no_grad=True, rng=None, *args, **kwargs)
Bases: Loss
Data augmentation consistency (DAC) loss.
Performs data augmentation in the measurement domain, as proposed in VORTEX: Physics-Driven Data Augmentations Using Consistency Training for Robust Accelerated MRI Reconstruction.
The loss is defined as follows:
\(\mathcal{L}(T_e\inverse{y,A},\inverse{T_i y,A T_e^{-1}})\)
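where \(T_i\) is a deepinv.transform.Transform for which we should learn an invariant mapping, and \(T_e\) is a deepinv.transform.Transform for which we should learn an equivariant mapping. In other words, the reconstruction should not change when \(T_i\) is applied to the measurements \(y\), and it should transform by \(T_e\) when the underlying image is transformed by \(T_e\).
Note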
If \(T_e\) is specified, the mapping is performed in the image domain and the model is assumed to take \(A^\top y\) as input.
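The formula can be made concrete with a minimal pure-PyTorch sketch. This is not deepinv's implementation: the pixel-mask operator \(A\), the circular shift used as \(T_e\), the Gaussian noise used as \(T_i\), the noise level, and the one-layer convolutional reconstructor are all toy assumptions chosen so that every term is easy to check. Following the Note, the reconstructor takes the back-projected measurements as input. Because a circular shift is a permutation, the adjoint of \(A T_e^{-1}\) is \(T_e A^\top\), so the augmented branch feeds \(T_e(A^\top T_i y)\) to the model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup (assumption): A is a random pixel mask, T_e is a circular shift,
# and T_i adds Gaussian noise to the measurements.
B, H, W = 4, 16, 16
mask = (torch.rand(1, 1, H, W) > 0.5).float()

def A(x):
    # Forward operator: keep only the masked pixels.
    return mask * x

def A_adjoint(y):
    # A diagonal mask is self-adjoint.
    return mask * y

def T_e(x, shift):
    # Equivariant transform: circular shift of the image.
    return torch.roll(x, shifts=shift, dims=(-2, -1))

def T_i(y, sigma=0.05):
    # Invariant transform: additive noise on the measurements.
    return y + sigma * torch.randn_like(y)

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in reconstructor R
metric = nn.MSELoss()

def dac_loss(y, shift, no_grad=True):
    # Original branch: T_e R(y, A), where R takes A^T y as input.
    # With no_grad=True no gradient flows through this branch.
    ctx = torch.no_grad() if no_grad else torch.enable_grad()
    with ctx:
        target = T_e(model(A_adjoint(y)), shift)
    # Augmented branch: R(T_i y, A T_e^{-1}). The adjoint of A T_e^{-1}
    # is T_e A^T, so the network input is T_e(A^T T_i y).
    x_aug_net = model(T_e(A_adjoint(T_i(y)), shift))
    return metric(x_aug_net, target)

x = torch.rand(B, 1, H, W)
y = A(x)
loss = dac_loss(y, shift=(3, -2))
loss.backward()
print(f"DAC loss: {loss.item():.4f}")
```

Setting `no_grad=False` in this sketch lets gradients flow through both branches, which mirrors the `no_grad` parameter described below.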
By default, \(T_i\) adds random noise (deepinv.transform.RandomNoise) and a random phase error (deepinv.transform.RandomPhaseError), and \(T_e\) applies a random shift (deepinv.transform.Shift) and a random rotation (deepinv.transform.Rotate).
Note
See Transforms for a guide to all available transforms and how to compose them. For example, you can compose further transforms such as:
Rotate(rng=rng, multiples=90) | Scale(factors=[0.75, 1.25], rng=rng) | Reflect(rng=rng)
- Parameters:
T_i (deepinv.transform.Transform) – invariant transform performed on \(y\).
T_e (deepinv.transform.Transform) – equivariant transform performed on \(A^\top y\).
metric (deepinv.loss.metric.Metric, torch.nn.Module) – metric for calculating loss.
no_grad (bool) – if True, only propagate gradients through the augmented branch, as in the original paper; if False, propagate gradients through both branches.
rng (torch.Generator) – torch random number generator to pass to transforms.
- forward(x_net, y, physics, model, **kwargs)
Data augmentation consistency loss forward pass.
- Parameters:
x_net (torch.Tensor) – Reconstructed image \(\inverse{y}\).
y (torch.Tensor) – Measurement.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction function.
- Returns:
(torch.Tensor) loss; the tensor size might be (1,) or (batch size,).