MRIMixin
- class deepinv.physics.MRIMixin
Bases: object
Mixin base class for MRI functionality. It provides helper functions for FFTs and mask checking.
- check_mask(mask=None, three_d=False, device='cpu', **kwargs)
Updates the MRI mask and verifies that its shape is (B, C, …, H, W) with C=2.
- Parameters:
mask (torch.nn.parameter.Parameter, torch.Tensor) – MRI subsampling mask.
three_d (bool) – if False, the mask should have at least 4 dimensions (B, C, H, W) for 2D data; if True, the mask should have 5 dimensions (B, C, D, H, W) for 3D data.
device (torch.device, str) – device on which the mask is intended to be placed.
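The shape rules above can be illustrated with a small standalone checker. This is a minimal sketch, not the deepinv implementation: validate_mask_shape is a hypothetical helper, and raising ValueError on a bad shape is an assumption about how violations are reported.

```python
import torch


def validate_mask_shape(mask, three_d=False, device="cpu"):
    """Hypothetical helper mirroring the documented shape rules of check_mask.

    Not the deepinv implementation: it only moves the mask to `device`
    and raises ValueError (an assumed behaviour) if the rules are violated.
    """
    mask = torch.as_tensor(mask).to(device)
    if three_d and mask.ndim != 5:
        raise ValueError(f"3D mask must be (B, C, D, H, W), got {tuple(mask.shape)}")
    if not three_d and mask.ndim < 4:
        raise ValueError(f"2D mask needs at least 4 dims (B, C, H, W), got {tuple(mask.shape)}")
    if mask.shape[1] != 2:
        raise ValueError(f"channel dim C must be 2, got {mask.shape[1]}")
    return mask


# Usage: a 2D mask with separate real/imaginary channels.
mask = validate_mask_shape(torch.ones(1, 2, 64, 64))
print(mask.shape)  # torch.Size([1, 2, 64, 64])
```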
- crop(x, crop=True)
Center-crop a 2D image according to img_size. This matches the RSS reconstructions of the original raw data in deepinv.datasets.FastMRISliceDataset. If img_size has an odd height, the crop is adjusted by one pixel to match the FastMRI data.
- Parameters:
x (torch.Tensor) – input tensor of shape (…,H,W)
crop (bool) – whether to perform crop, defaults to True
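A plain centered crop over the last two dimensions can be sketched as below. center_crop is a hypothetical helper, and it deliberately does not reproduce the odd-height one-pixel adjustment described above, whose exact direction is not specified here.

```python
import torch


def center_crop(x, size):
    """Center-crop the last two dims (H, W) of x to size=(h, w).

    Plain centered crop; the FastMRI odd-height adjustment is not modelled.
    """
    H, W = x.shape[-2:]
    h, w = size
    if h > H or w > W:
        raise ValueError(f"crop size {size} exceeds input size {(H, W)}")
    top = (H - h) // 2
    left = (W - w) // 2
    return x[..., top:top + h, left:left + w]


x = torch.arange(36.0).reshape(1, 1, 6, 6)
print(center_crop(x, (2, 2)))  # rows 2-3 and columns 2-3 of the 6x6 grid
```

Leading dimensions are kept untouched through the `...` index, so the same helper works for (B, C, H, W) or multicoil tensors.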
- static fft(x, dim=(-2, -1), norm='ortho')
Centered, orthogonal FFT.
- Parameters:
x (torch.Tensor) – input image of complex dtype of shape [B,…] where … is all dims to be transformed
dim (tuple) – fft transform dims, defaults to (-2, -1)
norm (str) – FFT normalization mode, see the docs for torch.fft.fftn(); defaults to “ortho”
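A centered FFT is commonly written as ifftshift, then fftn, then fftshift, so that the zero-frequency (DC) sample sits in the middle of k-space. The sketch below assumes that convention; centered_fft is a hypothetical name and the deepinv source remains the reference for the exact implementation.

```python
import torch


def centered_fft(x, dim=(-2, -1), norm="ortho"):
    """Centered FFT: zero frequency at the center of the output array.

    ifftshift moves the image center to index 0, fftn transforms, and
    fftshift moves the DC component back to the middle.
    """
    x = torch.fft.ifftshift(x, dim=dim)
    x = torch.fft.fftn(x, dim=dim, norm=norm)
    return torch.fft.fftshift(x, dim=dim)


x = torch.randn(2, 3, 8, 8, dtype=torch.complex64)
k = centered_fft(x)
# With norm="ortho" the transform is unitary, so the total energy is preserved.
print(torch.allclose(x.abs().pow(2).sum(), k.abs().pow(2).sum(), rtol=1e-4))
```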
- static ifft(x, dim=(-2, -1), norm='ortho')
Centered, orthogonal inverse FFT.
- Parameters:
x (torch.Tensor) – input kspace of complex dtype of shape [B,…] where … is all dims to be transformed
dim (tuple) – fft transform dims, defaults to (-2, -1)
norm (str) – FFT normalization mode, see the docs for torch.fft.fftn(); defaults to “ortho”
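The matching centered inverse applies ifftshift, ifftn and fftshift, which exactly undoes the centered forward transform. As above, this is a sketch under that assumed convention, with centered_ifft as a hypothetical name. It shows that a single sample at the k-space center maps to a flat image.

```python
import torch


def centered_ifft(k, dim=(-2, -1), norm="ortho"):
    """Centered inverse FFT, undoing ifftshift -> fftn -> fftshift."""
    k = torch.fft.ifftshift(k, dim=dim)
    k = torch.fft.ifftn(k, dim=dim, norm=norm)
    return torch.fft.fftshift(k, dim=dim)


# A single unit sample at the k-space center (the DC term) ...
k = torch.zeros(1, 8, 8, dtype=torch.complex64)
k[0, 4, 4] = 1.0
img = centered_ifft(k)
# ... becomes a flat image equal to 1/sqrt(8 * 8) = 0.125 at every pixel.
print(torch.allclose(img, torch.full_like(img, 0.125)))
```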
- im_to_kspace(x, three_d=False)
Convenience method that wraps fft.
- Parameters:
x (torch.Tensor) – input image of shape (B,2,…) of real dtype
three_d (bool) – whether MRI data is 3D or not, defaults to False
- Returns:
Tensor: output measurements of shape (B,2,…) of real dtype
- Return type:
torch.Tensor
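One way such a wrapper can bridge the real (B,2,…) layout and a complex FFT is to move the channel axis last, view it as complex, transform, and convert back. This is a sketch under that assumption; im_to_kspace_sketch is hypothetical and deepinv's internals may differ. With three_d=True it transforms the last three dimensions (D, H, W).

```python
import torch


def im_to_kspace_sketch(x, three_d=False):
    """Hypothetical stand-in for im_to_kspace: real (B, 2, ...) in, real (B, 2, ...) out.

    Dim 1 holds the real/imaginary parts; it is moved last so that
    torch.view_as_complex can pair them into one complex tensor.
    """
    dim = (-3, -2, -1) if three_d else (-2, -1)
    xc = torch.view_as_complex(x.movedim(1, -1).contiguous())  # (B, ..., H, W) complex
    kc = torch.fft.fftshift(
        torch.fft.fftn(torch.fft.ifftshift(xc, dim=dim), dim=dim, norm="ortho"), dim=dim
    )
    return torch.view_as_real(kc).movedim(-1, 1)  # back to (B, 2, ..., H, W)


x = torch.randn(4, 2, 32, 32)  # batch of 2D complex images stored as 2 real channels
y = im_to_kspace_sketch(x)
print(y.shape, y.dtype)  # torch.Size([4, 2, 32, 32]) torch.float32
```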
- kspace_to_im(y, three_d=False)
Convenience method that wraps inverse fft.
- Parameters:
y (torch.Tensor) – input measurements of shape (B,2,…) of real dtype
three_d (bool) – whether MRI data is 3D or not, defaults to False
- Returns:
Tensor: output image of shape (B,2,…) of real dtype
- Return type:
torch.Tensor
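The inverse wrapper can be sketched the same way, swapping fftn for ifftn. kspace_to_im_sketch is a hypothetical name and the channel-last complex view is an assumed design, not necessarily what deepinv does. The usage places a real-valued DC sample in channel 0 (the real part) and recovers a constant, purely real image.

```python
import torch


def kspace_to_im_sketch(y, three_d=False):
    """Hypothetical stand-in for kspace_to_im: inverse of a centered, orthonormal FFT."""
    dim = (-3, -2, -1) if three_d else (-2, -1)
    yc = torch.view_as_complex(y.movedim(1, -1).contiguous())
    xc = torch.fft.fftshift(
        torch.fft.ifftn(torch.fft.ifftshift(yc, dim=dim), dim=dim, norm="ortho"), dim=dim
    )
    return torch.view_as_real(xc).movedim(-1, 1)  # real (B, 2, ..., H, W)


# Real-valued DC sample at the k-space center.
y = torch.zeros(1, 2, 8, 8)
y[0, 0, 4, 4] = 1.0  # channel 0 = real part
x = kspace_to_im_sketch(y)
# The real channel should be constant 1/sqrt(64); the imaginary channel should vanish.
print(x[0, 0, 0, 0].item(), x[0, 1].abs().max().item())
```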
- static rss(x, multicoil=True, three_d=False)
Perform root-sum-square (RSS) reconstruction on multicoil data, defined as
\[\operatorname{RSS}(x) = \sqrt{\sum_{n=1}^{N} |x_n|^2}\]
where \(x_n\) are the coil images of \(x\), \(|\cdot|\) denotes the magnitude and \(N\) is the number of coils. Note that the sum is performed voxel-wise.
- Parameters:
x (torch.Tensor) – input image of shape (B,2,…) where 2 represents real and imaginary channels
multicoil (bool) – if True, assume x is of shape (B,2,N,…) and also reduce over the coil dimension N.
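The RSS formula maps directly to tensor reductions: the squared magnitude is real² + imag² over dim 1, then a sum over the coil dimension. This is a sketch of the formula, not the library code. rss_sketch is hypothetical, it ignores three_d because the reduction is shape-agnostic, and it returns (B, …) without a channel dimension, which may differ from deepinv's output shape.

```python
import torch


def rss_sketch(x, multicoil=True):
    """Root-sum-square per the formula above; x is real with 2 channels at dim 1.

    multicoil=True: x is (B, 2, N, ...) and the N coils are combined.
    Returns (B, ...); the library's output shape may differ.
    """
    mag_sq = x.pow(2).sum(dim=1)  # |x_n|^2 = real^2 + imag^2, per voxel
    if multicoil:
        mag_sq = mag_sq.sum(dim=1)  # sum over the coil dimension N
    return mag_sq.sqrt()


# Two coils at a single voxel: coil 1 = 3 + 0i, coil 2 = 0 + 4i.
x = torch.zeros(1, 2, 2, 1, 1)
x[0, 0, 0] = 3.0  # real part of coil 1
x[0, 1, 1] = 4.0  # imaginary part of coil 2
print(rss_sketch(x))  # sqrt(3^2 + 4^2) = 5
```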
Examples using MRIMixin:
data:image/s3,"s3://crabby-images/d7afc/d7afc03c8eecf55186146aa64fc9b8b8c6ae6314" alt=""
Self-supervised MRI reconstruction with Artifact2Artifact
data:image/s3,"s3://crabby-images/fa0fd/fa0fd65c41a28e3bb54cda3177d7bd8a45e03396" alt=""
Self-supervised learning with Equivariant Imaging for MRI.