VarNet
- class deepinv.models.VarNet(denoiser=None, sensitivity_model=None, num_cascades=12, mode='varnet')
Bases: ArtifactRemoval, MRIMixin
VarNet or E2E-VarNet model.
These models are from the papers Sriram et al., End-to-End Variational Networks for Accelerated MRI Reconstruction and Hammernik et al., Learning a variational network for reconstruction of accelerated MRI data.
The model performs unrolled iterations on either the image estimate x (as in the original VarNet paper) or the k-space y (as in E2E-VarNet).
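The difference between the two iteration spaces can be illustrated with a minimal single-coil NumPy sketch. This is not the library's implementation: `toy_regulariser` is a hypothetical stand-in for the learned denoiser, and `unrolled_recon`, `fft2c`, `ifft2c` and the step size `eta` are names and assumptions introduced only for this example.

```python
import numpy as np

def fft2c(x):
    """Orthonormal 2D FFT (image -> k-space)."""
    return np.fft.fft2(x, norm="ortho")

def ifft2c(k):
    """Orthonormal 2D inverse FFT (k-space -> image)."""
    return np.fft.ifft2(k, norm="ortho")

def toy_regulariser(x):
    """Hypothetical stand-in for the learned denoiser: a small step that
    pulls each pixel toward the mean of its four neighbours."""
    neighbours = (np.roll(x, 1, 0) + np.roll(x, -1, 0)
                  + np.roll(x, 1, 1) + np.roll(x, -1, 1)) / 4
    return 0.1 * (x - neighbours)

def unrolled_recon(y, mask, num_cascades=12, mode="varnet", eta=1.0):
    """Run `num_cascades` unrolled iterations in image space or k-space.

    `y` is assumed to be masked single-coil k-space, i.e. y == mask * y.
    """
    if mode == "varnet":
        x = ifft2c(y)  # zero-filled image as the starting estimate
        for _ in range(num_cascades):
            data_grad = ifft2c(mask * (fft2c(x) - y))  # A^H (A x - y)
            x = x - eta * data_grad - toy_regulariser(x)
        return x
    if mode == "e2e-varnet":
        k = y.copy()  # measured k-space as the starting estimate
        for _ in range(num_cascades):
            data_term = mask * (k - y)  # consistency on sampled locations
            reg_term = fft2c(toy_regulariser(ifft2c(k)))  # regularise in image space
            k = k - eta * data_term - reg_term
        return ifft2c(k)  # final IFFT maps the k-space estimate to an image
    raise ValueError(f"unknown mode: {mode!r}")

# Small synthetic demo: random image, random sampling mask.
rng = np.random.default_rng(0)
demo_image = rng.standard_normal((8, 8))
demo_mask = (rng.random((8, 8)) < 0.5).astype(float)
demo_y = demo_mask * fft2c(demo_image)
recon_image_space = unrolled_recon(demo_y, demo_mask, num_cascades=3, mode="varnet")
recon_kspace = unrolled_recon(demo_y, demo_mask, num_cascades=3, mode="e2e-varnet")
```

In this toy setting the two modes produce the same iterates up to floating-point error, because the orthonormal FFT is unitary and the same stand-in regulariser is applied in image space in both branches. That is specific to the sketch; it says nothing about how the trained models compare.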
Note
For single-coil MRI, either mode is valid. For multi-coil MRI, the VarNet mode simply sums over the coils, so E2E-VarNet is preferred. To estimate sensitivity maps for multi-coil MRI, pass in sensitivity_model.
Code loosely adapted from the E2E-VarNet implementation in facebookresearch/fastMRI.
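To see why summing over coils is the weaker option, the following NumPy sketch compares a plain sum of per-coil images with a sensitivity-weighted combination on fully-sampled synthetic data. It rests on several assumptions: the maps are hypothetical, smooth and normalised so that their squared magnitudes sum to one, the helper names `combine_sum` and `combine_with_maps` are made up for illustration, and a plain sum of per-coil images is only this sketch's reading of "sum over the coils". The library's exact reduction may differ.

```python
import numpy as np

def fft2c(x):
    """Orthonormal 2D FFT over the last two axes."""
    return np.fft.fft2(x, norm="ortho", axes=(-2, -1))

def ifft2c(k):
    """Orthonormal 2D inverse FFT over the last two axes."""
    return np.fft.ifft2(k, norm="ortho", axes=(-2, -1))

def combine_sum(coil_kspace):
    """Plain sum of the per-coil images (ignores coil sensitivities)."""
    return ifft2c(coil_kspace).sum(axis=0)

def combine_with_maps(coil_kspace, sens_maps):
    """Sensitivity-weighted combination: sum_c conj(S_c) * IFFT(y_c)."""
    return (np.conj(sens_maps) * ifft2c(coil_kspace)).sum(axis=0)

# Synthetic complex image and three hypothetical smooth coil maps.
rng = np.random.default_rng(0)
image = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
yy, xx = np.meshgrid(np.linspace(-1, 1, 16), np.linspace(-1, 1, 16), indexing="ij")
raw_maps = np.stack([
    np.exp(1j * np.pi * xx) * (1.5 + xx),
    np.exp(-1j * np.pi * yy) * (1.5 + yy),
    np.exp(1j * np.pi * (xx + yy)) * (1.5 - xx),
])
# Normalise so that sum_c |S_c|^2 == 1 at every pixel.
sens_maps = raw_maps / np.sqrt((np.abs(raw_maps) ** 2).sum(axis=0))

# Fully-sampled multi-coil k-space with shape (coils, H, W).
coil_kspace = fft2c(sens_maps * image)

plain = combine_sum(coil_kspace)
weighted = combine_with_maps(coil_kspace, sens_maps)
```

With normalised maps and full sampling, `weighted` recovers `image`, while `plain` equals the image multiplied pixel-wise by the sum of the maps, so it carries coil-dependent phase and magnitude modulation.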
- Parameters:
denoiser (Denoiser, torch.nn.Module) – backbone network that parametrises the gradient of the regulariser. If None, a small DnCNN is used.
sensitivity_model (torch.nn.Module) – network to jointly estimate coil sensitivity maps for multi-coil MRI. If None, no map estimation is performed. Unused for single-coil MRI.
num_cascades (int) – number of unrolled iterations (‘cascades’).
mode (str) – if ‘varnet’, iterate on the image x as in the original VarNet. If ‘e2e-varnet’, iterate on the k-space y as in E2E-VarNet.
- backbone_inference(tensor_in, physics, y)
Perform inference on input tensor.
Uses physics and y for data consistency. If necessary, performs a fully-sampled MRI IFFT on the model output.
- Parameters:
tensor_in (torch.Tensor) – input tensor as dictated by VarNet mode (either k-space or image)
physics (Physics) – forward physics for data consistency
y (torch.Tensor) – input measurements y for data consistency
- Returns:
(torch.Tensor) reconstructed image
- Return type: