MultiCoilMRI
- class deepinv.physics.MultiCoilMRI(mask=None, coil_maps=None, img_size=(320, 320), three_d=False, device=torch.device('cpu'), **kwargs)
Bases: MRIMixin, LinearPhysics
Multi-coil 2D or 3D MRI operator.
The linear operator operates in 2D slices or 3D volumes and is defined as:
\[y_n = \text{diag}(p) F \text{diag}(s_n) x\]
for \(n=1,\dots,N\) coils, where \(y_n\) are the measurements from the \(n\)th coil, \(\text{diag}(p)\) is the acceleration mask, \(F\) is the Fourier transform and \(\text{diag}(s_n)\) is the \(n\)th coil sensitivity.
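The forward model above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the helper names `to_complex`, `to_real` and `multicoil_forward` are hypothetical, the FFT is taken as orthonormal and un-centred (no `fftshift`), and only the 2D case with an (H,W) mask and (B,N,H,W) coil maps is handled.

```python
import torch


def to_complex(x):
    """(B, 2, ...) real tensor -> (B, ...) complex tensor (hypothetical helper)."""
    return torch.complex(x[:, 0], x[:, 1])


def to_real(z):
    """(B, ...) complex tensor -> (B, 2, ...) real tensor (hypothetical helper)."""
    return torch.stack([z.real, z.imag], dim=1)


def multicoil_forward(x, mask, coil_maps):
    """Sketch of y_n = diag(p) F diag(s_n) x for all coils n.

    x:         (B, 2, H, W) real tensor, channels are real and imaginary parts
    mask:      (H, W) binary tensor p
    coil_maps: (B, N, H, W) complex tensor s_n
    returns:   (B, 2, N, H, W) real tensor
    """
    xc = to_complex(x)                                # (B, H, W)
    coil_images = coil_maps * xc[:, None]             # diag(s_n) x -> (B, N, H, W)
    kspace = torch.fft.fft2(coil_images, norm="ortho")  # F over the last two dims
    kspace = mask * kspace                            # diag(p), broadcast over B and N
    return to_real(kspace)


torch.manual_seed(0)
x = torch.randn(1, 2, 4, 4)
mask = torch.ones(4, 4)                               # no acceleration
coil_maps = torch.randn(1, 5, 4, 4, dtype=torch.complex64)
y = multicoil_forward(x, mask, coil_maps)
print(y.shape)  # torch.Size([1, 2, 5, 4, 4])
```

The coil axis is inserted right after the real/imaginary channel axis, which reproduces the (B,C,N,H,W) measurement layout described for this operator.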
The data x should be of shape (B,C,H,W) or (B,C,D,H,W), where C=2 holds the real and imaginary channels and D is the optional depth dimension for 3D MRI. The resulting measurements y will then be of shape (B,C,N,(D,)H,W), where N is the coil dimension.
Note
We provide various random mask generators (e.g. Cartesian undersampling) that can be used directly with this physics; see e.g. deepinv.physics.generator.mri.RandomMaskGenerator. If the mask or coil maps are not passed, a mask and maps full of ones are used (i.e. no acceleration).
Note
You can also simulate basic birdcage coil sensitivity maps (https://mriquestions.com/birdcage-coil.html) by instead passing an integer to coil_maps, using MultiCoilMRI(coil_maps=N, img_size=x.shape) (note that this requires installing the sigpy library).
Note
This physics is directly compatible with FastMRI data using deepinv.datasets.FastMRISliceDataset. The dataset loads pairs (x, y) of RSS images and multi-coil kspace, where x = MultiCoilMRI().A_adjoint(y, rss=True, crop=True).
- Parameters:
mask (torch.Tensor) – binary sampling mask, which should have shape (H,W), (C,H,W), (B,C,H,W), or (B,C,…,H,W). If None, a mask of ones is generated with img_size.
coil_maps (torch.Tensor, str) – either a Tensor, an integer, or None. If a Tensor, complex-valued (i.e. of complex dtype) coil sensitivity maps, which should have shape (H,W), (N,H,W), (B,N,H,W) or (B,N,…,H,W). If None, flat coil maps of ones are generated with img_size. If an integer, birdcage coil maps are simulated with that number of coils (this requires sigpy to be installed).
img_size (tuple) – if mask or coil_maps is not specified, a flat mask or coil_maps of ones is created using img_size, where img_size can be of any shape specified above. If mask or coil_maps is provided, img_size is ignored.
three_d (bool) – if True, calculate the Fourier transform in 3D for 3D data (i.e. data of shape (B,C,D,H,W), where D is depth).
device (torch.device, str) – the device to use (i.e. cpu or gpu).
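As an illustration of a mask with one of the accepted shapes, the sketch below builds a simple (H,W) Cartesian column-undersampling mask by hand. It is only a stand-in for the library's mask generators: the function name `cartesian_mask` and its arguments are hypothetical, and it assumes a centred k-space layout in which the low frequencies sit in the middle columns.

```python
import torch


def cartesian_mask(height, width, acceleration=4, center_fraction=0.25, seed=0):
    """Hypothetical helper: binary (H, W) mask that keeps whole k-space columns.

    Each column is kept with probability 1 / acceleration, and a block of
    central columns (assumed to hold the low frequencies) is always kept.
    """
    gen = torch.Generator().manual_seed(seed)
    keep = torch.rand(width, generator=gen) < 1.0 / acceleration

    # Always sample the central block of columns.
    n_center = max(1, int(round(width * center_fraction)))
    start = (width - n_center) // 2
    keep[start:start + n_center] = True

    # Repeat the column pattern along the height dimension.
    return keep.float().expand(height, width).clone()


mask = cartesian_mask(8, 16)
print(mask.shape)  # torch.Size([8, 16])
```

A mask built this way is a float tensor of zeros and ones with shape (H,W), which is the simplest of the mask shapes listed in the parameters above.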
- Examples:
Multi-coil MRI operator:
>>> import torch
>>> from deepinv.physics import MultiCoilMRI
>>> seed = torch.manual_seed(0)  # Random seed for reproducibility
>>> x = torch.randn(1, 2, 2, 2)  # Define random 2x2 image B,C,H,W
>>> physics = MultiCoilMRI(img_size=x.shape)  # Define coil map of ones
>>> physics(x).shape  # B,C,N,H,W
torch.Size([1, 2, 1, 2, 2])
>>> coil_maps = torch.randn(1, 5, 2, 2, dtype=torch.complex64)  # Define 5-coil sensitivity maps
>>> physics.update_parameters(coil_maps=coil_maps)  # Update coil maps on the fly
>>> physics(x).shape
torch.Size([1, 2, 5, 2, 2])
- A(x, mask=None, coil_maps=None, **kwargs)
Applies linear operator.
Optionally update MRI mask or coil sensitivity maps on the fly.
- Parameters:
x (torch.Tensor) – image with shape (B,2,...,H,W).
mask (torch.Tensor) – optionally set the mask on-the-fly.
coil_maps (torch.Tensor) – optionally set the coil sensitivity maps on-the-fly.
- Returns:
(torch.Tensor) multi-coil kspace measurements with shape (B,2,N,...,H,W), where N is the coil dimension.
- A_adjoint(y, mask=None, coil_maps=None, rss=False, crop=False, **kwargs)
Applies adjoint linear operator.
Optionally update MRI mask or coil sensitivity maps on the fly.
- Parameters:
y (torch.Tensor) – multi-coil kspace measurements with shape (B,2,N,...,H,W), where N is the coil dimension.
mask (torch.Tensor) – optionally set the mask on-the-fly.
coil_maps (torch.Tensor) – optionally set the coil sensitivity maps on-the-fly.
rss (bool) – perform root-sum-square reconstruction. This option is provided to match the original data of deepinv.datasets.FastMRISliceDataset, such that x = MultiCoilMRI().A_adjoint(y, rss=True).
crop (bool) – if True, crop the last 2 dims of x to the last 2 dims of img_size. This option is provided to match the original data of deepinv.datasets.FastMRISliceDataset, such that x = MultiCoilMRI().A_adjoint(y, crop=True).
- Returns:
(torch.Tensor) image with shape (B,2,...,H,W) if not rss, else (B,1,...,H,W).
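The adjoint and the root-sum-square option can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions rather than the library's code: all function names are hypothetical, the FFT is orthonormal and un-centred, only 2D data with an (H,W) mask is handled, cropping is omitted, and the RSS shown is one common definition (magnitude of the per-coil inverse FFT, combined across coils). The mathematical adjoint of the forward model is \(x = \sum_n \text{diag}(\bar{s}_n) F^H \text{diag}(p) y_n\).

```python
import torch


def to_complex(x):
    """(B, 2, ...) real tensor -> (B, ...) complex tensor (hypothetical helper)."""
    return torch.complex(x[:, 0], x[:, 1])


def to_real(z):
    """(B, ...) complex tensor -> (B, 2, ...) real tensor (hypothetical helper)."""
    return torch.stack([z.real, z.imag], dim=1)


def multicoil_forward(x, mask, coil_maps):
    """Sketch of y_n = diag(p) F diag(s_n) x; returns (B, 2, N, H, W)."""
    coil_images = coil_maps * to_complex(x)[:, None]
    return to_real(mask * torch.fft.fft2(coil_images, norm="ortho"))


def multicoil_adjoint(y, mask, coil_maps):
    """Sketch of x = sum_n conj(s_n) F^H diag(p) y_n; returns (B, 2, H, W)."""
    yc = to_complex(y)                                       # (B, N, H, W)
    coil_images = torch.fft.ifft2(mask * yc, norm="ortho")   # F^H diag(p) y_n
    xc = (coil_maps.conj() * coil_images).sum(dim=1)         # combine coils
    return to_real(xc)


def rss_reconstruction(y):
    """Root-sum-square over coils of the per-coil inverse FFT; returns (B, 1, H, W)."""
    coil_images = torch.fft.ifft2(to_complex(y), norm="ortho")
    return coil_images.abs().pow(2).sum(dim=1, keepdim=True).sqrt()


# Dot-product check: <A x, y> should equal <x, A^H y>.
torch.manual_seed(0)
x = torch.randn(1, 2, 4, 4, dtype=torch.float64)
y = torch.randn(1, 2, 3, 4, 4, dtype=torch.float64)
coil_maps = torch.randn(1, 3, 4, 4, dtype=torch.complex128)
mask = (torch.rand(4, 4) > 0.5).to(torch.float64)

lhs = (multicoil_forward(x, mask, coil_maps) * y).sum()
rhs = (x * multicoil_adjoint(y, mask, coil_maps)).sum()
print(torch.allclose(lhs, rhs))  # True
```

The dot-product check is a quick way to confirm that a forward and adjoint pair are consistent; the RSS output has a single channel because it is a real-valued magnitude image.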
- simulate_birdcage_csm(n_coils)
Simulate birdcage coil sensitivity maps. Requires the sigpy library.
- Parameters:
n_coils (int) – number of coils N
- Returns:
(torch.Tensor) coil maps of complex dtype, with shape (N,H,W)
- update_parameters(mask=None, coil_maps=None, check_mask=True, **kwargs)
Update MRI subsampling mask and coil sensitivity maps.
- Parameters:
mask (torch.nn.parameter.Parameter, torch.Tensor) – MRI mask
coil_maps (torch.nn.parameter.Parameter, torch.Tensor) – MRI coil sensitivity maps
check_mask (bool) – check mask dimensions before updating