Phase2PhaseLoss

class deepinv.loss.Phase2PhaseLoss(tensor_size: Tuple[int], dynamic_model: bool = True, metric: Metric | Module = MSELoss(), device='cpu')

Bases: SplittingLoss

Phase2Phase loss for dynamic data.

Implements the dynamic measurement splitting loss from Phase2Phase: Respiratory Motion-Resolved Reconstruction of Free-Breathing Magnetic Resonance Imaging Using Deep Learning Without a Ground Truth for Improved Liver Imaging, proposed for free-breathing MRI. This is a special (temporal) case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details.

Splits the dynamic measurements into even time frames (“phases”), which are passed to the model, and odd phases, which are used to construct the loss. The physics mask (if it exists) is split in the same way: the even phases are used by the model (e.g. for data consistency in an unrolled network) and the odd phases are used as the reference. At test time, the full input is passed through the network.
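The even/odd phase split can be sketched in plain PyTorch. This is a minimal illustration, not the deepinv implementation: the helper name `split_even_odd_phases` is hypothetical, and it assumes a mask of shape (B, C, T, H, W) that is split by zeroing phases rather than removing them.

```python
import torch


def split_even_odd_phases(mask: torch.Tensor):
    """Hypothetical helper: split a (B, C, T, H, W) mask into even and odd phases."""
    if mask.dim() != 5:
        raise ValueError("expected a 5D mask of shape (B, C, T, H, W)")
    T = mask.shape[2]
    # 1.0 on even time frames (t = 0, 2, ...), 0.0 on odd ones,
    # reshaped so that it broadcasts over B, C, H and W.
    even = (torch.arange(T, device=mask.device) % 2 == 0).to(mask.dtype)
    even = even.view(1, 1, T, 1, 1)
    input_mask = mask * even       # phases given to the model
    loss_mask = mask * (1 - even)  # phases held out to build the loss
    return input_mask, loss_mask


# Same time-varying mask as in the example below: at time t, only row t is sampled.
mask = torch.zeros((1, 2, 4, 4, 4))
mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1

input_mask, loss_mask = split_even_odd_phases(mask)
y = torch.randn(1, 2, 4, 4, 4) * mask  # stand-in for measured k-space
y_input, y_reference = y * input_mask, y * loss_mask
```

The two masks are disjoint and add up to the original mask, so no measured sample is used both as model input and as reference.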

Warning

The model should be adapted before training using the method adapt_model() to include the splitting mechanism at the input.

Warning

This loss must only be used for dynamic or sequential measurements, i.e. where the data \(y\) and physics.mask (if it exists) have 5D shape (B, C, T, H, W).

Note

Phase2Phase can be used to reconstruct video sequences by setting dynamic_model=True and using the physics deepinv.physics.DynamicMRI. It can also be used to reconstruct static images when the k-space measurements form a time sequence, where each time step (phase) consists of sampled spokes such that the whole measurement is a set of non-overlapping spokes. To do this, set dynamic_model=False and use the physics deepinv.physics.SequentialMRI. See the examples below, or Self-supervised MRI reconstruction with Artifact2Artifact for a full MRI example.
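To make the static-image setting concrete, the sketch below simulates sequential measurements of a single image: every phase observes the same 2D Fourier transform, restricted to its own mask. It assumes orthonormal FFTs and that channels 0 and 1 hold the real and imaginary parts (matching C = 2 in the examples). `sequential_measurements` is a hypothetical helper, not the forward operator of deepinv.physics.SequentialMRI.

```python
import torch


def sequential_measurements(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical forward model: static x (B, 2, H, W) -> y (B, 2, T, H, W)."""
    # Interpret the two channels as the real and imaginary parts of the image.
    kspace = torch.fft.fft2(torch.complex(x[:, 0], x[:, 1]), norm="ortho")
    kspace = torch.stack([kspace.real, kspace.imag], dim=1)  # (B, 2, H, W)
    # Same k-space for every phase; each phase keeps only its own samples.
    return mask * kspace.unsqueeze(2)


mask = torch.zeros((1, 2, 4, 4, 4))
mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1  # one k-space row per phase

x = torch.rand((1, 2, 4, 4))
y = sequential_measurements(x, mask)

# Non-overlapping phases: each k-space location is sampled at most once.
assert mask.sum(dim=2).max() <= 1
```

Because the phases do not overlap, summing y over the time dimension gives the k-space sampled by the union of all phases.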

By default, the error is computed using the MSE metric, however any appropriate metric can be used.
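The role of the metric can be illustrated with a simplified Phase2Phase-style loss: re-measure the network output on the held-out odd phases only and compare it with the odd-phase measurements. This is a sketch under assumptions, not deepinv's implementation: the masked-FFT forward model and the names `masked_fft` and `odd_phase_loss` are hypothetical, and the output x_net is assumed to be dynamic, of shape (B, 2, T, H, W).

```python
import torch
from torch import nn


def masked_fft(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical dynamic forward model: x, mask (B, 2, T, H, W) -> masked k-space."""
    # 2D FFT of each frame (last two dims), channels read as real/imaginary parts.
    k = torch.fft.fft2(torch.complex(x[:, 0], x[:, 1]), norm="ortho")
    return mask * torch.stack([k.real, k.imag], dim=1)


def odd_phase_loss(x_net, y, mask, metric: nn.Module = nn.MSELoss()):
    """Compare x_net with y on the odd (held-out) phases using any metric."""
    T = mask.shape[2]
    odd = (torch.arange(T, device=mask.device) % 2 == 1).to(mask.dtype)
    loss_mask = mask * odd.view(1, 1, T, 1, 1)
    return metric(masked_fft(x_net, loss_mask), y * loss_mask)


mask = torch.zeros((1, 2, 4, 4, 4))
mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1
x = torch.rand((1, 2, 4, 4, 4))
y = masked_fft(x, mask)

x_net = torch.rand((1, 2, 4, 4, 4))               # stand-in for a network output
mse = odd_phase_loss(x_net, y, mask)               # default metric: MSE
l1 = odd_phase_loss(x_net, y, mask, nn.L1Loss())   # swapping in another metric
```

Any callable taking two tensors and returning a scalar can be passed as the metric in this sketch; the splitting logic is unchanged.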

Parameters:
  • tensor_size (tuple[int]) – size of the tensor to be masked, excluding the batch dimension, i.e. (C, T, H, W).

  • dynamic_model (bool) – set True if the model takes and returns time-resolved data, i.e. x of shape (B, C, T, H, W). Set False if x is a static image of shape (B, C, H, W).

  • metric (Metric, torch.nn.Module) – metric used for computing data consistency, which is set as the mean squared error by default.

  • device (str, torch.device) – torch device.


Example:

Dynamic MRI with Phase2Phase with a video network:

>>> import torch
>>> from deepinv.models import AutoEncoder, TimeAgnosticNet
>>> from deepinv.physics import DynamicMRI, SequentialMRI
>>> from deepinv.loss import Phase2PhaseLoss
>>>
>>> x = torch.rand((1, 2, 4, 4, 4)) # B, C, T, H, W
>>> mask = torch.zeros((1, 2, 4, 4, 4))
>>> mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1 # Create time-varying mask
>>>
>>> physics = DynamicMRI(mask=mask)
>>> loss = Phase2PhaseLoss((2, 4, 4, 4))
>>> model = TimeAgnosticNet(AutoEncoder(32, 2, 2)) # Example video network
>>> model = loss.adapt_model(model) # Adapt model to perform Phase2Phase
>>>
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save random mask in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True

Free-breathing MRI with Phase2Phase with an image network and sequential measurements:

>>> physics = SequentialMRI(mask=mask) # mask is B, C, T, H, W
>>> loss = Phase2PhaseLoss((2, 4, 4, 4), dynamic_model=False) # Process static images x
>>>
>>> model = AutoEncoder(32, 2, 2) # Example image reconstruction network
>>> model = loss.adapt_model(model) # Adapt model to perform Phase2Phase
>>>
>>> x = torch.rand((1, 2, 4, 4)) # B, C, H, W
>>> y = physics(x) # B, C, T, H, W
>>> x_net = model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True

adapt_model(model: Module) → Module

Apply Phase2Phase splitting to model input. Also perform time-averaging if a static model is used.

Parameters:

model (torch.nn.Module) – Reconstruction model.

Returns:

(torch.nn.Module) Model modified for evaluation.
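Conceptually, adapting a model wraps it so that during training it only sees the even phases, while in evaluation mode it receives the full input; with a static model the time dimension is also collapsed. The sketch below illustrates this with a hypothetical wrapper `EvenPhaseWrapper` that takes k-space and mask directly and averages each k-space location over the phases in which it was sampled. It is an illustration under those assumptions, not the module returned by adapt_model(), which works with deepinv physics objects.

```python
import torch
from torch import nn


class EvenPhaseWrapper(nn.Module):
    """Hypothetical wrapper: even-phase input in training, full input in eval."""

    def __init__(self, model: nn.Module, dynamic_model: bool = True):
        super().__init__()
        self.model = model
        self.dynamic_model = dynamic_model

    def forward(self, y: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Zero out odd phases in both the measurements and the mask.
            T = y.shape[2]
            even = (torch.arange(T, device=y.device) % 2 == 0).to(y.dtype)
            even = even.view(1, 1, T, 1, 1)
            y, mask = y * even, mask * even
        if not self.dynamic_model:
            # Time-average: divide each k-space location by the number of
            # phases that sampled it (unsampled locations stay zero).
            y = y.sum(dim=2) / mask.sum(dim=2).clamp(min=1)
        return self.model(y)


mask = torch.zeros((1, 2, 4, 4, 4))
mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1
y = torch.randn(1, 2, 4, 4, 4) * mask

dynamic = EvenPhaseWrapper(nn.Identity(), dynamic_model=True)
static = EvenPhaseWrapper(nn.Identity(), dynamic_model=False)

out_train = dynamic.train()(y, mask)  # odd phases zeroed
out_eval = dynamic.eval()(y, mask)    # full input passed through
out_static = static.eval()(y, mask)   # (B, C, H, W) after time-averaging
```

nn.Identity stands in for a real reconstruction network so that the effect of the wrapper is visible directly in its output.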

static split(mask: Tensor, y: Tensor, physics: Physics | None = None)

Override splitting to actually remove masked pixels. In Phase2Phase, this corresponds to masked phases (i.e. time steps).
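The difference between zeroing and removing phases shows up in a short sketch: instead of multiplying the measurements by a time mask, the kept phases are selected with boolean indexing, so the time dimension actually shrinks. `remove_masked_phases` is a hypothetical helper that assumes a per-phase keep/drop mask of length T; the documented split() additionally accepts the mask tensor and an optional physics argument.

```python
import torch


def remove_masked_phases(y: torch.Tensor, phase_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep only phases where phase_mask (length T) is nonzero."""
    keep = phase_mask.bool()
    return y[:, :, keep]  # boolean indexing along the time dimension


y = torch.randn(1, 2, 4, 4, 4)          # B, C, T, H, W
even = torch.arange(4) % 2 == 0          # keep phases 0 and 2

zeroed = y * even.view(1, 1, 4, 1, 1)    # zeroing: shape unchanged
removed = remove_masked_phases(y, even)  # removal: time dimension shrinks
```

With removal, a downstream network processes only the retained frames instead of zero-filled ones.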

Parameters:

Examples using Phase2PhaseLoss:

Self-supervised MRI reconstruction with Artifact2Artifact