AugmentConsistencyLoss#

class deepinv.loss.AugmentConsistencyLoss(T_i=None, T_e=None, metric=torch.nn.MSELoss(), no_grad=True, rng=None, *args, **kwargs)[source]#

Bases: Loss

Data augmentation consistency (DAC) loss.

Performs data augmentation in the measurement domain, as proposed in VORTEX: Physics-Driven Data Augmentations Using Consistency Training for Robust Accelerated MRI Reconstruction.

The loss is defined as follows:

\(\mathcal{L}\big(T_e\, f_\theta(y, A),\; f_\theta(T_i y,\; A T_e^{-1})\big)\)

where \(f_\theta(y, A)\) denotes the model's reconstruction from measurements \(y\) under the forward operator \(A\), \(T_i\) is a deepinv.transform.Transform for which the model should learn an invariant mapping, and \(T_e\) is a deepinv.transform.Transform for which the model should learn an equivariant mapping.

Note

If \(T_e\) is specified, the mapping is performed in the image domain and the model is assumed to take \(A^\top y\) as input.
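The following is a minimal, self-contained PyTorch sketch of this two-branch structure, using toy stand-ins (a single convolution for the model, identity physics, additive noise for \(T_i\), and a horizontal flip for \(T_e\)); it illustrates the idea only and is not deepinv's implementation:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(1, 1, 3, padding=1)            # toy stand-in for f_theta
    T_i = lambda y: y + 0.05 * torch.randn_like(y)   # invariant transform on y
    T_e = lambda x: torch.flip(x, dims=[-1])         # equivariant image transform
    metric = nn.MSELoss()

    y = torch.randn(4, 1, 32, 32)  # measurements, with identity physics A = I

    with torch.no_grad():          # no_grad=True: this branch acts as the target
        x1 = T_e(model(y))         # T_e f_theta(y, A)
    x2 = model(T_e(T_i(y)))        # f_theta(T_i y, A T_e^{-1}), realized in the
                                   # image domain as described in the note above

    loss = metric(x2, x1)
    loss.backward()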

By default, \(T_i\) applies random noise (deepinv.transform.RandomNoise) and a random phase error (deepinv.transform.RandomPhaseError), and \(T_e\) applies random shifts (deepinv.transform.Shift) and random rotations (deepinv.transform.Rotate).

Note

See Transforms for a guide on all available transforms, and how to compose them. For example, you can easily compose further transforms such as Rotate(rng=rng, multiples=90) | Scale(factors=[0.75, 1.25], rng=rng) | Reflect(rng=rng).
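As a sketch, a composed transform can be passed directly to the loss. The composition below is the one from the note above; the RandomNoise constructor arguments are assumptions:

    import torch
    from deepinv.loss import AugmentConsistencyLoss
    from deepinv.transform import Rotate, Scale, Reflect, RandomNoise

    rng = torch.Generator().manual_seed(0)
    loss = AugmentConsistencyLoss(
        T_i=RandomNoise(rng=rng),  # rng kwarg assumed from the common Transform interface
        T_e=Rotate(rng=rng, multiples=90)
            | Scale(factors=[0.75, 1.25], rng=rng)
            | Reflect(rng=rng),
    )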

Parameters:

    T_i (deepinv.transform.Transform) – transform for which the model should learn an invariant mapping, applied to the measurements \(y\). Default: random noise and random phase error.
    T_e (deepinv.transform.Transform) – transform for which the model should learn an equivariant mapping. Default: random shifts and rotations.
    metric – metric used to compare the two reconstructions. Default: torch.nn.MSELoss().
    no_grad (bool) – if True, the reconstruction from the original measurements is computed without gradients and treated as the target. Default: True.
    rng (torch.Generator) – random number generator.
forward(x_net, y, physics, model, **kwargs)[source]#

Data augmentation consistency loss forward pass.

Parameters:

    x_net (torch.Tensor) – reconstructed image, i.e. the model output.
    y (torch.Tensor) – measurement data.
    physics (deepinv.physics.Physics) – forward physics \(A\).
    model (torch.nn.Module) – reconstruction model \(f_\theta\).

Returns:

(torch.Tensor) loss, of size (1,) or (batch size,).
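For context, a hedged sketch of a single training step; the physics, model, and transform choices are illustrative assumptions, and only the loss call follows the signature documented here:

    import torch
    import deepinv as dinv

    # Illustrative setup: denoising physics and a DnCNN-based reconstruction
    # model are assumptions for this sketch, not requirements of the loss.
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))
    model = dinv.models.ArtifactRemoval(
        dinv.models.DnCNN(in_channels=1, out_channels=1, pretrained=None)
    )
    # Real-valued toy data, so swap the MRI-oriented defaults for simple transforms.
    loss_fn = dinv.loss.AugmentConsistencyLoss(
        T_i=dinv.transform.RandomNoise(), T_e=dinv.transform.Shift()
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    x = torch.randn(4, 1, 32, 32)     # stand-in ground-truth batch
    y = physics(x)                    # simulated noisy measurements
    x_net = model(y, physics)         # reconstruction f_theta(y, A)

    l = loss_fn(x_net=x_net, y=y, physics=physics, model=model)
    optimizer.zero_grad()
    l.mean().backward()               # reduce: output may be (1,) or (batch size,)
    optimizer.step()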