Phase2PhaseLoss
- class deepinv.loss.Phase2PhaseLoss(tensor_size: Tuple[int], dynamic_model: bool = True, metric: Metric | Module = MSELoss(), device='cpu')
Bases: SplittingLoss
Phase2Phase loss for dynamic data.
Implements the dynamic measurement splitting loss for free-breathing MRI from Phase2Phase: Respiratory Motion-Resolved Reconstruction of Free-Breathing Magnetic Resonance Imaging Using Deep Learning Without a Ground Truth for Improved Liver Imaging. This is a special (temporal) case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details.
Splits the dynamic measurements into even time frames ("phases"), which are passed to the model, and odd phases, which are used to construct the loss. The physics mask (if it exists) is split in the same way: the even phases are used by the model (e.g. for data consistency in an unrolled network) and the odd phases are used for the reference. At test time, the full input is passed through the network.
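The even/odd phase split described above can be illustrated on a plain 5D tensor. This is a minimal pure-PyTorch sketch, not deepinv's implementation: `split_even_odd_phases` is a hypothetical helper, and it assumes the data and mask share the same (B, C, T, H, W) shape, as in the examples on this page.

```python
import torch


def split_even_odd_phases(y: torch.Tensor, mask: torch.Tensor):
    """Hypothetical helper (not part of deepinv): split 5D data and mask along time.

    Even phases (t = 0, 2, 4, ...) form the model input; odd phases
    (t = 1, 3, 5, ...) are held out as the loss reference.
    """
    if y.ndim != 5 or mask.shape != y.shape:
        raise ValueError("expected y and mask of identical 5D shape (B, C, T, H, W)")
    y_in, mask_in = y[:, :, 0::2], mask[:, :, 0::2]    # model input (even phases)
    y_ref, mask_ref = y[:, :, 1::2], mask[:, :, 1::2]  # loss reference (odd phases)
    return (y_in, mask_in), (y_ref, mask_ref)


B, C, T, H, W = 1, 2, 4, 4, 4
mask = torch.zeros(B, C, T, H, W)
mask[:, :, torch.arange(T), torch.arange(T) % H, :] = 1  # one k-space row per phase
y = mask * torch.randn(B, C, T, H, W)                    # toy masked measurements

(y_in, mask_in), (y_ref, mask_ref) = split_even_odd_phases(y, mask)
print(y_in.shape, y_ref.shape)  # torch.Size([1, 2, 2, 4, 4]) torch.Size([1, 2, 2, 4, 4])
```

Slicing with a stride of 2 keeps the two halves disjoint along time, so no measured phase is used both as input and as reference.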
Warning
The model should be adapted before training using the method adapt_model() to include the splitting mechanism at the input.
Warning
Must only be used for dynamic or sequential measurements, i.e. where the data \(y\) and physics.mask (if it exists) have 5D shape (B, C, T, H, W).
Note
Phase2Phase can be used to reconstruct video sequences by setting dynamic_model=True and using the physics deepinv.physics.DynamicMRI. It can also be used to reconstruct static images when the k-space measurements form a time sequence in which each time step (phase) consists of sampled spokes, such that the whole measurement is a set of non-overlapping spokes. To do this, set dynamic_model=False and use the physics deepinv.physics.SequentialMRI. See the examples below, or Self-supervised MRI reconstruction with Artifact2Artifact for a full MRI example.
By default, the error is computed using the MSE metric; however, any appropriate metric can be used.
- Parameters:
tensor_size (tuple[int]) – size of the tensor to be masked, without the batch dimension, i.e. of shape (C, T, H, W).
dynamic_model (bool) – set True if the model takes and returns time-resolved data, i.e. x of shape (B, C, T, H, W). Set False if x are static images of shape (B, C, H, W). Default: True.
metric (Metric, torch.nn.Module) – metric used to compute data consistency. Defaults to the mean squared error.
device (str, torch.device) – torch device.
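Since metric accepts a torch.nn.Module, a custom error measure can be written as an ordinary module. This is a minimal sketch of a relative (normalised) MSE; `RelativeMSE` is hypothetical, and the sketch assumes the loss calls the metric as metric(estimate, reference), the way it would call torch.nn.MSELoss.

```python
import torch
from torch import nn


class RelativeMSE(nn.Module):
    """Hypothetical metric: MSE divided by the mean energy of the reference."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # avoids division by zero for an all-zero reference

    def forward(self, estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        err = torch.mean((estimate - reference) ** 2)
        return err / (torch.mean(reference ** 2) + self.eps)


metric = RelativeMSE()
ref = torch.ones(1, 2, 2, 4, 4)
print(metric(torch.zeros_like(ref), ref).item())

# Assumed usage (not executed here):
# loss = Phase2PhaseLoss((2, 4, 4, 4), metric=RelativeMSE())
```

Normalising by the reference energy makes the value insensitive to the overall scale of the measurements, which can help when k-space intensities vary strongly between datasets.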
- Example:
Dynamic MRI with Phase2Phase with a video network:
>>> import torch
>>> from deepinv.models import AutoEncoder, TimeAgnosticNet
>>> from deepinv.physics import DynamicMRI, SequentialMRI
>>> from deepinv.loss import Phase2PhaseLoss
>>>
>>> x = torch.rand((1, 2, 4, 4, 4)) # B, C, T, H, W
>>> mask = torch.zeros((1, 2, 4, 4, 4))
>>> mask[:, :, torch.arange(4), torch.arange(4) % 4, :] = 1 # Create time-varying mask
>>>
>>> physics = DynamicMRI(mask=mask)
>>> loss = Phase2PhaseLoss((2, 4, 4, 4))
>>> model = TimeAgnosticNet(AutoEncoder(32, 2, 2)) # Example video network
>>> model = loss.adapt_model(model) # Adapt model to perform Phase2Phase
>>>
>>> y = physics(x)
>>> x_net = model(y, physics, update_parameters=True) # save random mask in forward pass
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
Free-breathing MRI with Phase2Phase with an image network and sequential measurements:
>>> physics = SequentialMRI(mask=mask) # mask is B, C, T, H, W
>>> loss = Phase2PhaseLoss((2, 4, 4, 4), dynamic_model=False) # Process static images x
>>>
>>> model = AutoEncoder(32, 2, 2) # Example image reconstruction network
>>> model = loss.adapt_model(model) # Adapt model to perform Phase2Phase
>>>
>>> x = torch.rand((1, 2, 4, 4)) # B, C, H, W
>>> y = physics(x) # B, C, T, H, W
>>> x_net = model(y, physics, update_parameters=True)
>>> l = loss(x_net, y, physics, model)
>>> print(l.item() > 0)
True
- adapt_model(model: Module) → Module
Applies Phase2Phase splitting to the model input. Also performs time-averaging if a static model is used (dynamic_model=False).
- Parameters:
model (torch.nn.Module) – Reconstruction model.
- Returns:
(torch.nn.Module) Model modified for evaluation.
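Conceptually, the adapted model wraps the original network: during training only the even phases reach it, at test time the full input does, and a static model receives a time-averaged input. The sketch below is a hypothetical stand-in, not the wrapper returned by adapt_model(): `PhaseSplitWrapper` is an invented name, it calls the inner model as model(y) rather than model(y, physics, update_parameters=...), and it uses a plain mean over time, whereas deepinv's actual averaging may differ (for example by normalising with per-location sampling counts).

```python
import torch
from torch import nn


class PhaseSplitWrapper(nn.Module):
    """Hypothetical illustration of Phase2Phase input splitting.

    Training: only the even phases of y (B, C, T, H, W) reach the model.
    Evaluation: the full input is used.
    If dynamic_model is False, phases are averaged over time so that a
    static image model receives (B, C, H, W) input.
    """

    def __init__(self, model: nn.Module, dynamic_model: bool = True):
        super().__init__()
        self.model = model
        self.dynamic_model = dynamic_model

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        if self.training:
            y = y[:, :, 0::2]  # keep even phases only; odd phases are reserved for the loss
        if not self.dynamic_model:
            y = y.mean(dim=2)  # naive time-average: (B, C, T, H, W) -> (B, C, H, W)
        return self.model(y)


y = torch.rand(1, 2, 4, 4, 4)
wrapped = PhaseSplitWrapper(nn.Identity(), dynamic_model=True)
wrapped.train()
print(wrapped(y).shape)  # torch.Size([1, 2, 2, 4, 4])
wrapped.eval()
print(wrapped(y).shape)  # torch.Size([1, 2, 4, 4, 4])
```

Switching behaviour on the module's training flag mirrors the documented rule that the split is applied during training while the full input is passed through the network at test time.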
- static split(mask: Tensor, y: Tensor, physics: Physics | None = None)
Override splitting to actually remove masked pixels. In Phase2Phase, this corresponds to masked phases (i.e. time steps).
- Parameters:
mask (torch.Tensor) – Phase2Phase mask
y (torch.Tensor) – input data
physics (deepinv.physics.Physics) – forward physics
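The difference between removing masked phases and merely masking them can be seen on a small tensor. This pure-PyTorch sketch assumes a boolean per-phase mask (`phase_mask` is hypothetical; the real split receives a mask tensor) and contrasts multiplication, which keeps the time dimension and leaves zeros, with indexing, which drops the masked phases.

```python
import torch

# Toy data: phase t holds the constant value t + 1, shape (B, C, T, H, W) = (1, 1, 4, 2, 2)
y = torch.arange(1.0, 5.0).reshape(1, 1, 4, 1, 1).expand(1, 1, 4, 2, 2).clone()
phase_mask = torch.tensor([True, False, True, False])  # hypothetical: keep even phases

# Multiplying by the mask zeroes the masked phases but keeps T = 4
zeroed = y * phase_mask.view(1, 1, -1, 1, 1)

# Indexing the time axis removes the masked phases, leaving T = 2
removed = y[:, :, phase_mask]

print(zeroed.shape, removed.shape)  # torch.Size([1, 1, 4, 2, 2]) torch.Size([1, 1, 2, 2, 2])
print(removed[0, 0, :, 0, 0])       # tensor([1., 3.])
```

Removing phases means the model never sees placeholder zero frames, which matters for video networks that process every time step.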
Examples using Phase2PhaseLoss:
Self-supervised MRI reconstruction with Artifact2Artifact