Artifact2ArtifactLoss
- class deepinv.loss.Artifact2ArtifactLoss(tensor_size: Tuple[int], split_size: int | Tuple[int] = 2, dynamic_model: bool = True, metric: Metric | Module = MSELoss(), device='cpu')
Bases: Phase2PhaseLoss
Artifact2Artifact loss for dynamic data.
Implements the dynamic measurement splitting loss from RARE: Image Reconstruction using Deep Priors Learned without Ground Truth, for free-breathing MRI. This is a special case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details.
At model input, a random time-chunk is chosen from the dynamic measurements ("Artifact…"), and another random chunk is used to construct the loss ("…2Artifact"). The physics mask (if it exists) is split in the same way: the input chunk is used by the model (e.g. for data consistency in an unrolled network) and the output chunk is used as the reference. At test time, the full input is passed through the network. Note that this implementation is a Monte Carlo-style variant: at each iteration, the network output is compared against only one other chunk.
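The chunking step described above can be sketched in plain PyTorch. This is a minimal illustration, not deepinv's implementation: the helper split_time_chunks is hypothetical, it assumes the time axis is dimension 2 of a (B, C, T, H, W) mask, and it assumes the target chunk is always drawn distinct from the input chunk.

```python
import torch


def split_time_chunks(mask: torch.Tensor, split_size: int, generator=None):
    """Hypothetical helper (not part of deepinv).

    Splits a 5D mask of shape (B, C, T, H, W) into an input-chunk mask and a
    target-chunk mask, each covering split_size consecutive time steps.
    """
    T = mask.shape[2]
    if T % split_size != 0:
        raise ValueError("split_size must divide the time dimension exactly")
    n_chunks = T // split_size
    if n_chunks < 2:
        raise ValueError("need at least two chunks to pick an input and a target")

    # Assumption: input and target are two distinct, randomly chosen chunks
    perm = torch.randperm(n_chunks, generator=generator)
    i_in, i_out = int(perm[0]), int(perm[1])

    def chunk_mask(i: int) -> torch.Tensor:
        m = torch.zeros_like(mask)
        t0, t1 = i * split_size, (i + 1) * split_size
        m[:, :, t0:t1] = mask[:, :, t0:t1]  # keep the physics mask inside chunk i only
        return m

    return chunk_mask(i_in), chunk_mask(i_out), (i_in, i_out)


# Usage: no physics mask, so start from all ones (every entry sampled)
mask = torch.ones(1, 2, 4, 4, 4)
y = torch.rand(1, 2, 4, 4, 4)
mask_in, mask_out, (i_in, i_out) = split_time_chunks(mask, split_size=2)
y_in = mask_in * y    # "Artifact…": what the network sees during training
y_out = mask_out * y  # "…2Artifact": reference used in the loss
```

With T = 4 and split_size = 2 there are exactly two chunks, so the input and target chunks together cover every time step.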
Warning
The model should be adapted before training using the method adapt_model() to include the splitting mechanism at the input.
Warning
Must only be used for dynamic or sequential measurements, i.e. where the data \(y\) and physics.mask (if it exists) are 5D tensors of shape (B, C, T, H, W).
Note
Artifact2Artifact can be used to reconstruct video sequences by setting dynamic_model=True and using the physics deepinv.physics.DynamicMRI. It can also be used to reconstruct static images, where the k-space measurements form a time-sequence in which each time step (phase) consists of sampled spokes, such that the whole measurement is a set of non-overlapping spokes. To do this, set dynamic_model=False and use the physics deepinv.physics.SequentialMRI. See below for an example, or Self-supervised MRI reconstruction with Artifact2Artifact for a full MRI example.
By default, the error is computed using the MSE metric; however, any appropriate metric can be used.
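The non-overlapping sequential setting can be illustrated with a toy example. This sketch assumes Cartesian rows as a stand-in for radial spokes and a masked-identity operator in place of the MRI Fourier transform, so it only shows the data layout: a static image x of shape (B, C, H, W) is measured once per phase, each phase samples a disjoint set of rows, and the phases together cover the full grid.

```python
import torch

B, C, T, H, W = 1, 2, 4, 4, 4

# Phase t samples k-space row t only (same mask construction as in the examples)
mask = torch.zeros(B, C, T, H, W)
mask[:, :, torch.arange(T), torch.arange(T) % H, :] = 1

# Each location is sampled by at most one phase, so the phases do not overlap
overlap_free = bool((mask.sum(dim=2) <= 1).all())

# Toy operator (assumption): y_t = mask_t * x for a static image x
x = torch.rand(B, C, H, W)
y = mask * x.unsqueeze(2)  # broadcast x over time -> (B, C, T, H, W)

# Summing over phases recovers every sampled location of x
y_union = y.sum(dim=2)  # (B, C, H, W)
```

Here T equals H, so the union of all phases samples every row and y_union equals x; with fewer phases than rows, only the sampled rows would be recovered.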
- Parameters:
tensor_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, i.e. of shape (C, T, H, W).
split_size (int, tuple[int]) – time-length of each chunk. Must divide tensor_size[1] exactly. If a tuple, one value is randomly selected each time.
dynamic_model (bool) – set True if using a model that inputs and outputs time-data, i.e. x of shape (B, C, T, H, W). Set False if x are static images of shape (B, C, H, W).
metric (Metric, torch.nn.Module) – metric used for computing data consistency, which is set as the mean squared error by default.
device (str, torch.device) – torch device.
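The divisibility rule and the tuple behaviour of split_size can be sketched without any tensors. This is a minimal, hypothetical helper (time_chunks is not a deepinv function): it draws one split_size from a tuple, checks that it divides T = tensor_size[1], and partitions the time indices into chunks from which an input and a target chunk can be sampled.

```python
import random
from typing import List, Sequence, Union


def time_chunks(T: int, split_size: Union[int, Sequence[int]],
                rng: random.Random) -> List[range]:
    """Hypothetical helper: partition time indices 0..T-1 into equal chunks.

    If split_size is a tuple, one value is drawn at random for this call,
    mirroring the documented behaviour of the split_size parameter.
    """
    if not isinstance(split_size, int):
        split_size = rng.choice(list(split_size))
    if T % split_size != 0:
        raise ValueError(f"split_size={split_size} does not divide T={T}")
    return [range(s, s + split_size) for s in range(0, T, split_size)]


rng = random.Random(0)
chunks = time_chunks(8, (2, 4), rng)  # split_size drawn from (2, 4)
input_chunk, target_chunk = rng.sample(chunks, 2)  # two distinct chunks
```

Calling time_chunks(6, 4, rng) raises a ValueError, since 4 does not divide 6.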
- Example:
Dynamic MRI with Artifact2Artifact with a video network:
>>> import torch
>>> from deepinv.models import AutoEncoder, TimeAgnosticNet
>>> from deepinv.physics import DynamicMRI, SequentialMRI
>>> from deepinv.loss import Artifact2ArtifactLoss
>>>
>>> x = torch.rand((1, 2, 4, 4, 4)) # B, C, T, H, W
>>> mask = torch.zeros((1, 2, 4, 4, 4))
>>> mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1 # Create time-varying mask
>>>
>>> physics = DynamicMRI(mask=mask)
>>> loss = Artifact2ArtifactLoss((2, 4, 4, 4))
>>> model = TimeAgnosticNet(AutoEncoder(32, 2, 2)) # Example video network
>>> model = loss.adapt_model(model) # Adapt model to perform Artifact2Artifact
>>>
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save random mask in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
Free-breathing MRI with Artifact2Artifact with an image network and sequential measurements:
>>> physics = SequentialMRI(mask=mask) # mask is B, C, T, H, W
>>> loss = Artifact2ArtifactLoss((2, 4, 4, 4), dynamic_model=False) # Process static images x
>>>
>>> model = AutoEncoder(32, 2, 2) # Example image reconstruction network
>>> model = loss.adapt_model(model) # Adapt model to perform Artifact2Artifact
>>>
>>> x = torch.rand((1, 2, 4, 4)) # B, C, H, W
>>> y = physics(x) # B, C, T, H, W
>>> x_net = model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
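The examples rely on loss.adapt_model to add the splitting mechanism at the model input. The hypothetical ChunkInputWrapper below sketches that behaviour in plain PyTorch; it is not deepinv's adapted model, which also receives physics and an update_parameters flag. In training mode it keeps one random time chunk of y and stores the chunk mask for use by a loss; in eval mode the full measurements are passed through, matching the test-time behaviour described above.

```python
import torch
import torch.nn as nn


class ChunkInputWrapper(nn.Module):
    """Hypothetical sketch of a model adapted for Artifact2Artifact-style training."""

    def __init__(self, net: nn.Module, split_size: int):
        super().__init__()
        self.net = net
        self.split_size = split_size
        self.mask = None  # last input-chunk mask, kept so a loss can pick another chunk

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return self.net(y)  # test time: full input
        T = y.shape[2]
        if T % self.split_size != 0:
            raise ValueError("split_size must divide the time dimension exactly")
        i = torch.randint(T // self.split_size, (1,)).item()
        mask = torch.zeros_like(y)
        mask[:, :, i * self.split_size:(i + 1) * self.split_size] = 1
        self.mask = mask
        return self.net(mask * y)  # training: only one time chunk reaches the network


model = ChunkInputWrapper(nn.Identity(), split_size=2)
y = torch.ones(1, 2, 4, 4, 4)

model.train()
out_train = model(y)  # only one chunk of 2 time steps is non-zero
model.eval()
out_test = model(y)   # full input
```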
- forward(x_net, y, physics, model, **kwargs)
Computes the measurement splitting loss.
- Parameters:
x_net (torch.Tensor) – Network output, i.e. the reconstruction returned by the adapted model.
y (torch.Tensor) – Measurements.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction function.
- Returns:
(torch.Tensor) loss.
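The loss computation can be sketched with a toy forward operator. This is a hypothetical illustration, not deepinv's forward: split_loss assumes A(x) = mask * x instead of the MRI physics, takes the target-chunk mask directly, and broadcasts a static reconstruction of shape (B, C, H, W) over time, as when dynamic_model=False. The metric argument mirrors the configurable metric parameter.

```python
import torch
import torch.nn as nn


def split_loss(x_net: torch.Tensor, y: torch.Tensor, mask_out: torch.Tensor,
               metric: nn.Module = nn.MSELoss()) -> torch.Tensor:
    """Hypothetical sketch: compare the reconstruction with a held-out chunk.

    Uses a toy operator A(x) = mask_out * x (assumption, not MRI physics).
    """
    if x_net.dim() == 4:
        # Static image (B, C, H, W): repeat over time by broadcasting
        x_net = x_net.unsqueeze(2)
    return metric(mask_out * x_net, mask_out * y)


y = torch.ones(1, 2, 4, 4, 4)
mask_out = torch.zeros_like(y)
mask_out[:, :, 2:4] = 1  # target chunk: time steps 2 and 3

loss_exact = split_loss(y.clone(), y, mask_out)                      # video output, exact fit
loss_static = split_loss(torch.ones(1, 2, 4, 4), y, mask_out)        # static output, exact fit
loss_l1 = split_loss(torch.zeros_like(y), y, mask_out, nn.L1Loss())  # another metric
```

Only the entries inside the target chunk contribute to the error, so a reconstruction that fits the held-out chunk exactly gives zero loss regardless of what it does elsewhere.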
Examples using Artifact2ArtifactLoss:
Self-supervised MRI reconstruction with Artifact2Artifact