Artifact2ArtifactLoss

class deepinv.loss.Artifact2ArtifactLoss(tensor_size: Tuple[int], split_size: int | Tuple[int] = 2, dynamic_model: bool = True, metric: Metric | Module = MSELoss(), device='cpu')

Bases: Phase2PhaseLoss

Artifact2Artifact loss for dynamic data.

Implements the dynamic measurement splitting loss from “RARE: Image Reconstruction using Deep Priors Learned without Ground Truth” for free-breathing MRI. This is a special case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details.

At the model input, a random time chunk is chosen from the dynamic measurements (“Artifact…”), and another random chunk is chosen for constructing the loss (“…2Artifact”). Likewise, the physics mask (if it exists) is split: the input chunk is used by the model (e.g. for data consistency in an unrolled network) and the output chunk is used as the reference. At test time, the full input is passed through the network. Note that this implementation performs a Monte Carlo-style version, in which the network output is compared to only one other chunk per iteration.
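The chunk selection described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the deepinv implementation: the helper name split_time_mask is hypothetical, and it only shows how a 5D mask (B, C, T, H, W) might be divided into an input chunk and a distinct target chunk along the time axis.

```python
import torch


def split_time_mask(mask: torch.Tensor, split_size: int, generator=None):
    """Hypothetical helper (not part of deepinv): split a 5D mask of shape
    (B, C, T, H, W) into an input-chunk mask and a different target-chunk mask."""
    T = mask.shape[2]
    if T % split_size != 0:
        raise ValueError("split_size must divide the number of time steps exactly")
    n_chunks = T // split_size
    if n_chunks < 2:
        raise ValueError("at least two chunks are needed to form an input/target pair")

    # Draw two distinct chunks at random: one target per call (Monte Carlo style)
    perm = torch.randperm(n_chunks, generator=generator)
    in_idx, out_idx = perm[0].item(), perm[1].item()

    def window(idx: int) -> torch.Tensor:
        # 1 inside the chosen time chunk, 0 elsewhere, broadcastable to 5D
        w = torch.zeros(T, dtype=mask.dtype, device=mask.device)
        w[idx * split_size:(idx + 1) * split_size] = 1
        return w.view(1, 1, T, 1, 1)

    return mask * window(in_idx), mask * window(out_idx)


# Usage: an all-ones mask with T = 4 and chunks of length 2
mask = torch.ones((1, 2, 4, 4, 4))
in_mask, out_mask = split_time_mask(mask, split_size=2, generator=torch.Generator().manual_seed(0))
print((in_mask * out_mask).sum().item())  # prints 0.0: the two chunks never overlap
```

Drawing both indices from a single permutation is one simple way to guarantee the target chunk differs from the input chunk.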

Warning

Before training, the model should be adapted with the method adapt_model() so that it includes the splitting mechanism at the input.
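To make the train/test behaviour concrete, the following is a hypothetical wrapper, not what adapt_model() actually returns. In training mode it feeds the model only one randomly chosen time chunk of y and remembers which chunk was used (so that a loss could pick a different one); in evaluation mode it passes the full y through. For brevity, the wrapped model here takes only y, whereas adapted deepinv models are called as model(y, physics, ...).

```python
import torch
import torch.nn as nn


class ChunkInputWrapper(nn.Module):
    """Hypothetical sketch of input splitting (not the deepinv adapt_model API)."""

    def __init__(self, model: nn.Module, split_size: int):
        super().__init__()
        self.model = model
        self.split_size = split_size
        self.in_window = None  # last input chunk, kept so a loss can avoid reusing it

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test time: the full measurement sequence goes through the network
            return self.model(y)
        T = y.shape[2]  # y has shape (B, C, T, H, W)
        idx = torch.randint(T // self.split_size, (1,)).item()
        window = torch.zeros(T, dtype=y.dtype, device=y.device)
        window[idx * self.split_size:(idx + 1) * self.split_size] = 1
        self.in_window = window
        # Training: zero out every time step outside the chosen chunk
        return self.model(y * window.view(1, 1, T, 1, 1))


wrapped = ChunkInputWrapper(nn.Identity(), split_size=2)
y = torch.ones((1, 2, 4, 3, 3))
wrapped.train()
print(wrapped(y).sum().item())  # prints 36.0: only 2 of the 4 time steps survive
wrapped.eval()
print(torch.equal(wrapped(y), y))  # prints True: full input at test time
```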

Warning

Must only be used for dynamic or sequential measurements, i.e. where data \(y\) and physics.mask (if it exists) are of 5D shape (B, C, T, H, W).

Note

Artifact2Artifact can be used to reconstruct video sequences by setting dynamic_model=True and using the physics deepinv.physics.DynamicMRI. It can also be used to reconstruct static images, where the k-space measurements are a time sequence in which each time step (phase) consists of sampled spokes, such that the whole measurement is a set of non-overlapping spokes. To do this, set dynamic_model=False and use the physics deepinv.physics.SequentialMRI. See the examples below, or Self-supervised MRI reconstruction with Artifact2Artifact for a full MRI example.
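The non-overlapping condition for the sequential (static) case can be checked directly on the mask. The sketch below rebuilds the time-varying mask used in the examples on this page, where each phase samples a single k-space row; treating Cartesian rows as stand-ins for radial spokes is a simplifying assumption.

```python
import torch

T, H, W = 4, 4, 4
mask = torch.zeros((1, 2, T, H, W))  # B, C, T, H, W
# Phase t samples k-space row t (rows stand in for spokes here)
mask[:, :, torch.arange(T), torch.arange(T) % H, :] = 1

# Fraction of k-space acquired by each individual phase (the dynamic view)
per_phase = mask.mean(dim=(0, 1, 3, 4))
# How many phases sample each k-space location (the static view)
coverage = mask.sum(dim=2)

print(per_phase.tolist())           # prints [0.25, 0.25, 0.25, 0.25]
print(bool((coverage <= 1).all()))  # prints True: spokes never overlap across phases
print(bool((coverage == 1).all()))  # prints True: together they cover every row once
```

Each phase on its own is heavily undersampled, which matters when dynamic_model=True, while the union of all phases is what a static reconstruction with dynamic_model=False can exploit.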

By default, the error is computed using the MSE metric; however, any appropriate metric can be used.
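As a rough illustration of swapping the metric, the sketch below compares a re-measured reconstruction with the held-out chunk of k-space under either MSE or an L1 loss. The forward model A(x) = mask · FFT(x) on a real, single-channel image sequence of shape (B, T, H, W) is a simplifying assumption, not the deepinv DynamicMRI operator, and masked_kspace_loss is a hypothetical name. In deepinv itself, the metric is chosen through the metric argument.

```python
import torch
import torch.nn as nn


def masked_kspace_loss(x_net, y, out_mask, metric=nn.MSELoss()):
    """Hypothetical sketch: compare A(x_net) with y on the target chunk only,
    assuming A(x) = mask * FFT2(x) for real images x of shape (B, T, H, W)."""
    pred = out_mask * torch.fft.fft2(x_net, norm="ortho")
    target = out_mask * y
    # Compare real and imaginary parts so that real-valued metrics can be used
    return metric(torch.view_as_real(pred), torch.view_as_real(target))


torch.manual_seed(0)
x = torch.rand((1, 4, 8, 8))
out_mask = torch.zeros((1, 4, 8, 8))
out_mask[:, 2:4] = 1  # target chunk: time steps 2 and 3
y = out_mask * torch.fft.fft2(x, norm="ortho")

x_noisy = x + 0.1 * torch.randn_like(x)
print(masked_kspace_loss(x, y, out_mask).item())  # perfect reconstruction on the target chunk
print(masked_kspace_loss(x_noisy, y, out_mask).item() > 0)
print(masked_kspace_loss(x_noisy, y, out_mask, metric=nn.L1Loss()).item() > 0)
```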

Parameters:
  • tensor_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, i.e. of shape (C, T, H, W).

  • split_size (int, tuple[int]) – time length of each chunk. Must divide tensor_size[1] exactly. If a tuple, one value is randomly selected each time.

  • dynamic_model (bool) – set True if the model takes and returns time-resolved data, i.e. x of shape (B, C, T, H, W). Set False if x is a static image of shape (B, C, H, W).

  • metric (Metric, torch.nn.Module) – metric used for computing data consistency, which is set as the mean squared error by default.

  • device (str, torch.device) – torch device.
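The documented constraints on tensor_size and split_size can be summarised in a few lines. The helper choose_split_size below is hypothetical and only restates the rules listed above (a (C, T, H, W) shape, every candidate chunk length dividing T, and a random pick when a tuple is given); it is not taken from the deepinv source.

```python
import random
from typing import Optional, Sequence, Tuple, Union


def choose_split_size(
    tensor_size: Sequence[int],
    split_size: Union[int, Tuple[int, ...]],
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: validate split_size against tensor_size = (C, T, H, W)
    and return the chunk length to use for the current iteration."""
    if len(tensor_size) != 4:
        raise ValueError("tensor_size must be (C, T, H, W), without the batch dimension")
    T = tensor_size[1]
    candidates = (split_size,) if isinstance(split_size, int) else tuple(split_size)
    for s in candidates:
        if T % s != 0:
            raise ValueError(f"split_size {s} does not divide T = {T}")
    # A tuple means a chunk length is drawn at random each time
    return (rng or random).choice(candidates)


print(choose_split_size((2, 4, 4, 4), 2))                   # prints 2
print(choose_split_size((2, 8, 4, 4), (2, 4)) in (2, 4))    # prints True
```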


Example:

Dynamic MRI with Artifact2Artifact, using a video network:

>>> import torch
>>> from deepinv.models import AutoEncoder, TimeAgnosticNet
>>> from deepinv.physics import DynamicMRI, SequentialMRI
>>> from deepinv.loss import Artifact2ArtifactLoss
>>>
>>> x = torch.rand((1, 2, 4, 4, 4)) # B, C, T, H, W
>>> mask = torch.zeros((1, 2, 4, 4, 4))
>>> mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1 # Create time-varying mask
>>>
>>> physics = DynamicMRI(mask=mask)
>>> loss = Artifact2ArtifactLoss((2, 4, 4, 4))
>>> model = TimeAgnosticNet(AutoEncoder(32, 2, 2)) # Example video network
>>> model = loss.adapt_model(model) # Adapt model to perform Artifact2Artifact
>>>
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save random mask in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True

Free-breathing MRI with Artifact2Artifact, using an image network and sequential measurements:

>>> physics = SequentialMRI(mask=mask) # mask is B, C, T, H, W
>>> loss = Artifact2ArtifactLoss((2, 4, 4, 4), dynamic_model=False) # Process static images x
>>>
>>> model = AutoEncoder(32, 2, 2) # Example image reconstruction network
>>> model = loss.adapt_model(model) # Adapt model to perform Artifact2Artifact
>>>
>>> x = torch.rand((1, 2, 4, 4)) # B, C, H, W
>>> y = physics(x) # B, C, T, H, W
>>> x_net = model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
forward(x_net, y, physics, model, **kwargs)

Computes the measurement splitting loss.

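Parameters:
  • x_net (torch.Tensor) – network output (reconstruction).

  • y (torch.Tensor) – dynamic or sequential measurements of shape (B, C, T, H, W).

  • physics (deepinv.physics.Physics) – forward operator associated with the measurements.

  • model (torch.nn.Module) – reconstruction model, adapted with adapt_model().

Returns:

(torch.Tensor) loss.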

Examples using Artifact2ArtifactLoss:

Self-supervised MRI reconstruction with Artifact2Artifact