# Source code for deepinv.physics.mri

from __future__ import annotations
from warnings import warn
import numpy as np
import torch
from torch import Tensor

from deepinv.physics.forward import DecomposablePhysics, LinearPhysics
from deepinv.utils.mixins import MRIMixin, TimeMixin


class MRI(MRIMixin, DecomposablePhysics):
    r"""
    Single-coil accelerated 2D or 3D magnetic resonance imaging.

    The linear operator operates in 2D slices or 3D volumes and is defined as

    .. math::

        y = MFx

    where :math:`M` applies a mask (subsampling operator), and :math:`F` is the 2D or 3D discrete Fourier Transform.
    This operator has a simple singular value decomposition, so it inherits the structure of
    :class:`deepinv.physics.DecomposablePhysics` and thus has a fast pseudo-inverse and prox operators.

    The complex images :math:`x` and measurements :math:`y` should be of size (B, C,..., H, W) with C=2, where the
    first channel corresponds to the real part and the second channel corresponds to the imaginary part.
    The ``...`` is an optional depth dimension for 3D MRI data.

    A fixed mask can be set at initialisation, or a new mask can be set either at forward
    (using ``physics(x, mask=mask)``) or using ``update``.

    .. note::

        We provide various random mask generators (e.g. Cartesian undersampling) that can be used directly with this
        physics. See e.g. :class:`deepinv.physics.generator.mri.RandomMaskGenerator`
        If mask is not passed, a mask full of ones is used (i.e. no acceleration).

    .. note::

        This physics is directly compatible with FastMRI data using :class:`deepinv.datasets.FastMRISliceDataset`.
        The dataset loads pairs of magnitude images and kspace ``(x, y)`` where
        ``x = MRI().A_adjoint(y, mag=True, crop=True)``.

    :param torch.Tensor mask: binary mask, where 1s represent sampling locations, and 0s otherwise.
        The mask size can either be (H,W), (C,H,W), (B,C,H,W), (B,C,...,H,W) where H, W are the image height and
        width, C is channels (which should be 2) and B is batch size.
    :param tuple img_size: if mask not specified, flat mask of ones is created using ``img_size``,
        where ``img_size`` can be of any shape specified above. If mask provided, ``img_size`` is ignored.
    :param bool three_d: if ``True``, calculate Fourier transform in 3D for 3D data (i.e. data of shape (B,C,D,H,W)
        where D is depth).
    :param torch.device device: cpu or gpu.

    |sep|

    :Examples:

        Single-coil accelerated MRI operator with subsampling mask:

        >>> from deepinv.physics import MRI
        >>> seed = torch.manual_seed(0) # Random seed for reproducibility
        >>> x = torch.randn(1, 2, 2, 2) # Define random 2x2 image
        >>> mask = 1 - torch.eye(2) # Define subsampling mask
        >>> physics = MRI(mask=mask) # Define mask at initialisation
        >>> physics(x)
        tensor([[[[ 0.0000, -1.4290],
                  [ 0.4564, -0.0000]],
        <BLANKLINE>
                 [[ 0.0000,  1.8622],
                  [ 0.0603, -0.0000]]]])
        >>> physics = MRI(img_size=x.shape) # No subsampling
        >>> physics(x)
        tensor([[[[ 2.2908, -1.4290],
                  [ 0.4564, -0.1814]],
        <BLANKLINE>
                 [[ 0.3744,  1.8622],
                  [ 0.0603, -0.6209]]]])
        >>> physics.update(mask=mask) # Update mask on the fly
        >>> physics(x)
        tensor([[[[ 0.0000, -1.4290],
                  [ 0.4564, -0.0000]],
        <BLANKLINE>
                 [[ 0.0000,  1.8622],
                  [ 0.0603, -0.0000]]]])

    """

    def __init__(
        self,
        mask: Tensor | None = None,
        img_size: tuple | None = (320, 320),
        three_d: bool = False,
        device: torch.device | str = "cpu",
        **kwargs,
    ):
        super().__init__(device=device, **kwargs)
        self.three_d = three_d
        self.img_size = img_size

        if mask is None:
            # No mask given: sample everywhere (no acceleration).
            mask = torch.ones(*img_size, device=device)

        # Check and register the mask. ``three_d`` is forwarded so that the
        # initial mask is normalised exactly like masks passed later through
        # ``update_parameters`` (which also forwards ``three_d``).
        self.register_buffer("mask", self.check_mask(mask, three_d=self.three_d))
        # Keep img_size in sync with the checked mask (batch dim dropped).
        self.img_size = self.mask.shape[1:]
        self.to(device)

    def V_adjoint(self, x: Tensor) -> Tensor:
        """Apply ``im_to_kspace`` to ``x``, in 3D if ``self.three_d`` is set."""
        return self.im_to_kspace(x, three_d=self.three_d)

    def V(self, x: Tensor) -> Tensor:
        """Apply ``kspace_to_im`` to ``x``, in 3D if ``self.three_d`` is set."""
        return self.kspace_to_im(x, three_d=self.three_d)

    def A_adjoint(
        self,
        y: Tensor,
        mask: Tensor = None,
        mag: bool = False,
        crop: bool = False,
        **kwargs,
    ) -> Tensor:
        """Adjoint operator.

        Optionally perform crop and magnitude to match FastMRI data.

        By default, crop and magnitude are not performed. By setting ``mag=crop=True``, the outputs will be
        consistent with :class:`deepinv.datasets.FastMRISliceDataset`.

        :param torch.Tensor y: input kspace of shape (B,C,...,H,W)
        :param torch.Tensor mask: optionally set mask on-the-fly.
        :param bool mag: perform complex magnitude. This option is provided to match the original data of
            :class:`deepinv.datasets.FastMRISliceDataset`, such that ``x = MRI().A_adjoint(y, mag=True)``.
        :param bool crop: if ``True``, crop last 2 dims of x to last 2 dims of img_size. This option is provided to
            match the original data of :class:`deepinv.datasets.FastMRISliceDataset`, such that
            ``x = MRI().A_adjoint(y, crop=True)``.
        :return: (:class:`torch.Tensor`) image of shape (B,C,...,H,W) where C=1 if ``mag`` else 2.
        """
        x = super().A_adjoint(y, mask, **kwargs)

        if mag:
            x = self.rss(x, multicoil=False)

        if crop:
            x = self.crop(x, crop=crop)

        return x  # (B,C,...,H,W) where C=1 if mag else 2

    def noise(self, x, **kwargs):
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)`

        The output of the noise model is multiplied by the mask before ``U`` is applied.

        :param torch.Tensor x: clean measurements
        :return torch.Tensor: noisy measurements
        """
        noise = self.U(self.noise_model(x, **kwargs) * self.mask)
        return noise

    def update_parameters(self, mask: Tensor = None, check_mask: bool = True, **kwargs):
        """Update MRI subsampling mask.

        :param torch.nn.parameter.Parameter, torch.Tensor mask: MRI mask
        :param bool check_mask: check mask dimensions before updating
        """
        if mask is not None:
            mask = (
                self.check_mask(
                    mask=mask,
                    # Falls back to 2D if ``three_d`` has not been set on the instance.
                    three_d=getattr(self, "three_d", False),
                )
                if check_mask
                else mask
            )

        super().update_parameters(mask=mask, **kwargs)
class MultiCoilMRI(MRIMixin, LinearPhysics):
    r"""
    Multi-coil 2D or 3D MRI operator.

    The linear operator operates in 2D slices or 3D volumes and is defined as:

    .. math::

        y_n = \text{diag}(p) F \text{diag}(s_n) x

    for :math:`n=1,\dots,N` coils, where :math:`y_n` are the measurements from the nth coil,
    :math:`\text{diag}(p)` is the acceleration mask, :math:`F` is the Fourier transform and
    :math:`\text{diag}(s_n)` is the nth coil sensitivity.

    The data ``x`` should be of shape (B,C,H,W) or (B,C,D,H,W) where C=2 is the channels (real and imaginary) and D
    is optional dimension for 3D MRI. Then, the resulting measurements ``y`` will be of shape (B,C,N,(D,)H,W) where
    N is the coils dimension.

    .. note::

        We provide various random mask generators (e.g. Cartesian undersampling) that can be used directly with this
        physics. See e.g. :class:`deepinv.physics.generator.mri.RandomMaskGenerator`.
        If mask or coil maps are not passed, a mask and maps full of ones is used (i.e. no acceleration).

    .. note::

        You can also simulate basic `birdcage coil sensitivity maps <https://mriquestions.com/birdcage-coil.html>`_
        by passing instead an integer to ``coil_maps`` using ``MultiCoilMRI(coil_maps=N, img_size=x.shape)``
        (note this requires installing the ``sigpy`` library).

    .. note::

        This physics is directly compatible with FastMRI data using :class:`deepinv.datasets.FastMRISliceDataset`.
        The dataset loads pairs of RSS images and multicoil kspace ``(x, y)`` where
        ``x = MultiCoilMRI().A_adjoint(y, rss=True, crop=True)``.

    :param torch.Tensor mask: binary sampling mask which should have shape (H,W), (C,H,W), (B,C,H,W), or
        (B,C,...,H,W). If None, generate mask of ones with ``img_size``.
    :param torch.Tensor, int coil_maps: either ``Tensor``, integer, or ``None``. If ``Tensor``, complex valued
        (i.e. of complex dtype) coil sensitivity maps which should have shape (H,W), (N,H,W), (B,N,H,W) or
        (B,N,...,H,W). If None, generate flat coil maps of ones with ``img_size``. If integer, simulate birdcage
        coil maps with integer number of coils (this requires ``sigpy`` installed).
    :param tuple img_size: if ``mask`` or ``coil_maps`` not specified, flat ``mask`` or ``coil_maps`` of ones are
        created using ``img_size``, where ``img_size`` can be of any shape specified above. If ``mask`` or
        ``coil_maps`` provided, ``img_size`` is ignored.
    :param bool three_d: if ``True``, calculate Fourier transform in 3D for 3D data (i.e. data of shape (B,C,D,H,W)
        where D is depth).
    :param torch.device, str device: specify which device you want to use (i.e, cpu or gpu).

    |sep|

    :Examples:

        Multi-coil MRI operator:

        >>> from deepinv.physics import MultiCoilMRI
        >>> seed = torch.manual_seed(0) # Random seed for reproducibility
        >>> x = torch.randn(1, 2, 2, 2) # Define random 2x2 image B,C,H,W
        >>> physics = MultiCoilMRI(img_size=x.shape) # Define coil map of ones
        >>> physics(x).shape # B,C,N,H,W
        torch.Size([1, 2, 1, 2, 2])
        >>> coil_maps = torch.randn(1, 5, 2, 2, dtype=torch.complex64) # Define 5-coil sensitivity maps
        >>> physics.update(coil_maps=coil_maps) # Update coil maps on the fly
        >>> physics(x).shape
        torch.Size([1, 2, 5, 2, 2])

    """

    def __init__(
        self,
        mask: Tensor | None = None,
        coil_maps: Tensor | int | None = None,
        img_size: tuple | None = (320, 320),
        three_d: bool = False,
        device: torch.device | str = torch.device("cpu"),
        **kwargs,
    ):
        super().__init__(device=device, **kwargs)
        self.img_size = img_size
        self.three_d = three_d

        if mask is None:
            # No mask given: sample everywhere (no acceleration).
            mask = torch.ones(*img_size, device=device)

        if coil_maps is None:
            # Flat maps over the spatial dims only (last 2 dims, or last 3 in 3D).
            coil_maps = torch.ones(
                (self.img_size[-2:] if not self.three_d else self.img_size[-3:]),
                dtype=torch.complex64,
                device=device,
            )
        elif isinstance(coil_maps, int):
            coil_maps = self.simulate_birdcage_csm(n_coils=coil_maps).to(device)

        self.register_buffer("mask", self.check_mask(mask, three_d=self.three_d))
        self.register_buffer(
            "coil_maps", self.check_coil_maps(coil_maps, three_d=self.three_d)
        )
        self.to(device)

    def A(
        self, x: Tensor, mask: Tensor = None, coil_maps: Tensor = None, **kwargs
    ) -> Tensor:
        r"""
        Applies linear operator.

        Optionally update MRI mask or coil sensitivity maps on the fly.

        :param torch.Tensor x: image with shape `(B,2,...,H,W)`.
        :param torch.Tensor mask: optionally set the mask on-the-fly.
        :param torch.Tensor coil_maps: optionally set the coil sensitivity maps on-the-fly.
        :returns: (:class:`torch.Tensor`) multi-coil kspace measurements with shape `(B,2,N,...,H,W)` where `N` is
            coil dimension.
        """
        self.update_parameters(mask=mask, coil_maps=coil_maps, **kwargs)

        # Insert a coil axis into x so that it broadcasts against the coil maps.
        Sx = self.coil_maps * self.to_torch_complex(x)[:, None]  # [B,N,...,H,W]
        FSx = self.fft(Sx, dim=(-3, -2, -1) if self.three_d else (-2, -1))
        # Insert a coil axis into the mask so that it broadcasts over coils.
        MFSx = self.mask[:, :, None] * self.from_torch_complex(FSx)  # [B,2,N,...,H,W]
        return MFSx

    def noise(self, x, **kwargs) -> Tensor:
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)` and takes the mask into account.

        :param torch.Tensor x: clean measurements
        :param None, float noise_level: optional noise level parameter
        :return: noisy measurements
        """
        return self.mask[:, :, None] * self.noise_model(x, **kwargs)

    def A_adjoint(
        self,
        y: Tensor,
        mask: Tensor = None,
        coil_maps: Tensor = None,
        rss: bool = False,
        crop: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""
        Applies adjoint linear operator.

        Optionally update MRI mask or coil sensitivity maps on the fly.

        :param torch.Tensor y: multi-coil kspace measurements with shape [B,2,N,...,H,W] where N is coil dimension.
        :param torch.Tensor mask: optionally set the mask on-the-fly.
        :param torch.Tensor coil_maps: optionally set the coil sensitivity maps on-the-fly.
        :param bool rss: perform root-sum-square reconstruction. This option is provided to match the original data
            of :class:`deepinv.datasets.FastMRISliceDataset`, such that ``x = MultiCoilMRI().A_adjoint(y, rss=True)``.
        :param bool crop: if ``True``, crop last 2 dims of x to last 2 dims of img_size. This option is provided to
            match the original data of :class:`deepinv.datasets.FastMRISliceDataset`, such that
            ``x = MultiCoilMRI().A_adjoint(y, crop=True)``.
        :returns: (:class:`torch.Tensor`) image with shape `(B,2,...,H,W)` if not rss else `(B,1,...,H,W)`
        """
        assert y.shape[1] == 2, "y must be of shape (B,2,N,...,H,W)"
        self.update_parameters(mask=mask, coil_maps=coil_maps, **kwargs)

        My = self.to_torch_complex(self.mask[:, :, None] * y)  # [B,N,...,H,W]
        FiMy = self.ifft(My, dim=(-3, -2, -1) if self.three_d else (-2, -1))

        if rss:
            # The RSS branch does not use the coil maps.
            x = self.from_torch_complex(FiMy)
            x = self.rss(x, multicoil=True)  # [B,1,...,H,W]
        else:
            # Use conj as coil maps are elementwise multiplication
            SiFiMy = torch.sum(torch.conj(self.coil_maps) * FiMy, dim=1)  # [B,...,H,W]
            x = self.from_torch_complex(SiFiMy)  # [B,2,...,H,W]

        return self.crop(x, crop=crop)

    def A_dagger(
        self, y: Tensor, mask: Tensor = None, coil_maps: Tensor = None, **kwargs
    ) -> Tensor:
        r"""
        Computes least squares solution to the MRI inverse problem, as proposed in
        `SENSE: Sensitivity encoding for fast MRI <https://doi.org/10.1002/(SICI)1522-2594(199911)42:5%3C952::AID-MRM16%3E3.0.CO;2-S>`_.

        By default uses conjugate gradient solver. Overwrite default solver arguments by passing `kwargs`.
        See :func:`deepinv.optim.linear.least_squares` for details.

        The MRI mask or coil sensitivity maps are updated if passed as inputs to the function.

        :param torch.Tensor y: multi-coil kspace measurements with shape [B,2,N,...,H,W] where N is coil dimension.
        :param torch.Tensor mask: optionally set the mask on-the-fly.
        :param torch.Tensor coil_maps: optionally set the coil sensitivity maps on-the-fly.
        :param dict kwargs: kwargs to pass to base :meth:`deepinv.physics.LinearPhysics.A_dagger`.
        :returns: (:class:`torch.Tensor`) image with shape `(B,2,...,H,W)`
        """
        # kwargs are intended for the solver only, so they are not forwarded here.
        self.update_parameters(mask=mask, coil_maps=coil_maps)
        return super().A_dagger(y, **kwargs)

    def update_parameters(
        self,
        mask: Tensor = None,
        coil_maps: Tensor = None,
        check_mask: bool = True,
        check_coil_maps: bool = True,
        **kwargs,
    ):
        """Update MRI subsampling mask and coil sensitivity maps.

        After updating, ``img_size`` is refreshed from the current mask and a warning is emitted if the spatial
        dimensions of the stored coil maps no longer match it.

        :param torch.nn.parameter.Parameter, torch.Tensor mask: MRI mask
        :param torch.nn.parameter.Parameter, torch.Tensor coil_maps: MRI coil sensitivity maps
        :param bool check_mask: check mask dimensions before updating
        :param bool check_coil_maps: check coil maps dimensions before updating
        """
        if mask is not None:
            mask = (
                self.check_mask(mask=mask, three_d=self.three_d) if check_mask else mask
            )

        if coil_maps is not None:
            coil_maps = (
                self.check_coil_maps(coil_maps, three_d=self.three_d)
                if check_coil_maps
                else coil_maps
            )

        super().update_parameters(mask=mask, coil_maps=coil_maps, **kwargs)

        # Update image size with latest mask shape
        self.img_size = self.mask.shape[1:]

        if self.coil_maps is not None and self.coil_maps.shape[2:] != self.img_size[1:]:
            # Report the stored maps, not the ``coil_maps`` argument: the argument
            # is None whenever only the mask is being updated.
            warn(
                f"After updating parameters, img_size {self.img_size} in MultiCoilMRI is incompatible with coil_maps shape {self.coil_maps.shape} in the spatial dims."
            )

    @staticmethod
    def check_coil_maps(coil_maps: Tensor, three_d: bool) -> Tensor:
        """Check coil maps dimensions.

        Leading singleton dimensions are prepended until the maps have 4 dims (B,N,H,W), or 5 dims (B,N,D,H,W)
        when ``three_d`` is set.

        :param torch.Tensor coil_maps: coil sensitivity maps
        :param bool three_d: whether the maps are for 3D data
        :return torch.Tensor: checked coil sensitivity maps
        :raises ValueError: if ``coil_maps`` is not of complex dtype
        """
        while len(coil_maps.shape) < (
            4 if not three_d else 5
        ):  # to B,N,H,W or B,N,D,H,W
            coil_maps = coil_maps.unsqueeze(0)

        if not coil_maps.is_complex():
            raise ValueError("coil_maps should be of torch complex dtype.")

        return coil_maps

    def simulate_birdcage_csm(self, n_coils: int) -> Tensor:
        """Simulate birdcage coil sensitivity maps. Requires library ``sigpy``.

        The map shape requested from ``sigpy`` is ``n_coils`` followed by the last 2 dims of ``self.img_size``
        (or the last 3 dims when ``self.three_d`` is set).

        :param int n_coils: number of coils N
        :return torch.Tensor: coil maps of complex dtype of shape (N,H,W)
        :raises ImportError: if ``sigpy`` is not installed
        """
        try:
            from sigpy.mri import birdcage_maps
        except ImportError:  # pragma: no cover
            raise ImportError(
                "sigpy is required to simulate coil maps. Install it using pip install sigpy"
            )

        coil_maps = birdcage_maps(
            (n_coils,)
            + (self.img_size[-2:] if not self.three_d else self.img_size[-3:])
        )
        return torch.tensor(coil_maps).type(torch.complex64)

    @staticmethod
    def estimate_coil_maps(
        y: Tensor, calib_size: int = 24, use_cupy: bool = False
    ) -> Tensor:
        """Estimate coil sensitivity maps using ESPIRiT.

        This was proposed in `ESPIRiT — An Eigenvalue Approach to Autocalibrating Parallel MRI: Where SENSE meets GRAPPA <https://onlinelibrary.wiley.com/doi/10.1002/mrm.24751>`_.

        Note this uses a suboptimal undifferentiable unbatched implementation provided by `sigpy`:
        the calibration is run once per batch element and the results are stacked.
        Optionally use `cupy` to accelerate on GPU, only if `cupy` is installed and a GPU is available.

        :param torch.Tensor y: multi-coil kspace measurements with shape [B,2,N,...,H,W] where N is coil dimension.
        :param int calib_size: optional square auto-calibration size in pixels, used by `sigpy`.
        :param bool use_cupy: whether to attempt to use cupy for GPU acceleration.
        :return: torch.Tensor of coil maps of complex dtype and shape [B,N,...,H,W]
        :raises ImportError: if ``sigpy`` is not installed
        """
        try:
            from sigpy.mri.app import EspiritCalib
            import sigpy as sp  # pragma: no cover
        except ImportError:  # pragma: no cover
            raise ImportError(
                "sigpy is required to estimate sens maps. Install it using pip install sigpy"
            )

        if use_cupy:  # pragma: no cover
            try:
                import cupy as cp

                # Only keep the GPU path if cupy reports a usable CUDA device.
                use_cupy = cp.cuda.is_available()
            except ImportError:
                warn(
                    "cupy is not installed, using cpu for coil map estimation. Install cupy to speed up computation."
                )
                use_cupy = False

        complex_y = MRIMixin.to_torch_complex(y)

        if use_cupy:  # pragma: no cover
            # Hand the tensor to cupy through DLPack, moving it to CUDA first if needed.
            if y.device.type == "cuda":
                cupy_y = cp.from_dlpack(complex_y)
            else:
                cupy_y = cp.from_dlpack(complex_y.to("cuda"))

            cupy_maps = cp.stack(
                [
                    EspiritCalib(
                        yb, calib_size, show_pbar=False, device=cupy_y.device
                    ).run()
                    for yb in cupy_y
                ]
            )
            torch_maps = torch.from_dlpack(cupy_maps)
        else:
            # NOTE(review): -1 presumably selects the CPU device in sigpy - confirm.
            device = sp.Device(-1)
            maps = np.stack(
                [
                    EspiritCalib(yb, calib_size, show_pbar=False, device=device).run()
                    for yb in complex_y.numpy(force=True)
                ]
            )
            # NOTE(review): this branch yields a CPU tensor regardless of y.device.
            torch_maps = torch.from_numpy(maps)

        return torch_maps
class DynamicMRI(MRI, TimeMixin):
    r"""
    Single-coil accelerated dynamic magnetic resonance imaging.

    The linear operator operates in 2D+t videos and is defined as

    .. math::

        y_t = M_t Fx_t

    where :math:`M_t` applies a time-varying mask, and :math:`F` is the 2D discrete Fourier Transform.
    This operator has a simple singular value decomposition, so it inherits the structure of
    :class:`deepinv.physics.DecomposablePhysics` and thus has a fast pseudo-inverse and prox operators.

    The complex images :math:`x` and measurements :math:`y` should be of size (B, 2, T, H, W) where the first
    channel corresponds to the real part and the second channel corresponds to the imaginary part.

    A fixed mask can be set at initialisation, or a new mask can be set either at forward
    (using ``physics(x, mask=mask)``) or using ``update``.

    .. note::

        We provide various random mask generators (e.g. Cartesian undersampling) that can be used directly with this
        physics. See e.g. :class:`deepinv.physics.generator.mri.RandomMaskGenerator`

    :param torch.Tensor mask: binary mask, where 1s represent sampling locations, and 0s otherwise.
        The mask size can either be (H,W), (T,H,W), (C,T,H,W) or (B,C,T,H,W) where H, W are the image height and
        width, T is time-steps, C is channels (typically 2) and B is batch size.
    :param tuple img_size: if mask not specified, flat mask of ones is created using ``img_size``,
        where ``img_size`` can be of any shape specified above. If mask provided, ``img_size`` is ignored.
    :param torch.device device: cpu or gpu.

    |sep|

    :Examples:

        Single-coil accelerated 2D+t MRI operator:

        >>> from deepinv.physics import DynamicMRI
        >>> seed = torch.manual_seed(0) # Random seed for reproducibility
        >>> x = torch.randn(1, 2, 2, 2, 2) # Define random video of shape (B,C,T,H,W)
        >>> mask = torch.rand_like(x) > 0.75 # Define random 4x subsampling mask
        >>> physics = DynamicMRI(mask=mask) # Physics with given mask
        >>> physics.update(mask=mask) # Alternatively set mask on-the-fly
        >>> physics(x)
        tensor([[[[[-0.0000,  0.7969],
                   [-0.0000, -0.0000]],
        <BLANKLINE>
                  [[-0.0000, -1.9860],
                   [-0.0000, -0.4453]]],
        <BLANKLINE>
        <BLANKLINE>
                 [[[ 0.0000,  0.0000],
                   [-0.8137, -0.0000]],
        <BLANKLINE>
                  [[-0.0000, -0.0000],
                   [-0.0000,  1.1135]]]]])

    """

    def A(self, x: Tensor, mask: Tensor = None, **kwargs) -> torch.Tensor:
        """Forward operator applied to a video.

        The input and the mask (expanded to the shape of ``x``) are passed through ``flatten``, the static
        :class:`MRI` forward operator is applied, and the result is passed through ``unflatten``.

        :param torch.Tensor x: input video of shape `(B,2,T,H,W)`
        :param torch.Tensor mask: optionally set mask on-the-fly, see class docs for shapes allowed.
        :return: (:class:`torch.Tensor`) measurements, unflattened with the batch size of ``x``.
        """
        mask = self.check_mask(self.mask if mask is None else mask).to(x.device)

        mask_flatten = self.flatten(mask.expand(*x.shape))
        y = self.unflatten(
            super().A(self.flatten(x), mask_flatten, check_mask=False),
            batch_size=x.shape[0],
        )

        # The parent call was given the flattened mask; store the
        # unflattened (time-resolved) mask on the physics afterwards.
        self.update_parameters(mask=mask, check_mask=False, **kwargs)
        return y

    def A_adjoint(
        self, y: Tensor, mask: Tensor = None, mag: bool = False, **kwargs
    ) -> Tensor:
        """Adjoint operator.

        Optionally perform magnitude to reduce channel dimension.

        :param torch.Tensor y: input kspace of shape `(B,2,T,H,W)`
        :param torch.Tensor mask: optionally set mask on-the-fly, see class docs for shapes allowed.
        :param bool mag: perform complex magnitude.
        :return: (:class:`torch.Tensor`) adjoint output, unflattened with the batch size of ``y``.
        """
        mask = self.check_mask(self.mask if mask is None else mask).to(y.device)

        mask_flatten = self.flatten(mask.expand(*y.shape))
        x = self.unflatten(
            super().A_adjoint(
                self.flatten(y), mask=mask_flatten, check_mask=False, mag=mag
            ),
            batch_size=y.shape[0],
        )

        # The parent call was given the flattened mask; store the
        # unflattened (time-resolved) mask on the physics afterwards.
        self.update_parameters(mask=mask, check_mask=False, **kwargs)
        return x

    def A_dagger(self, y: Tensor, mask: Tensor = None, **kwargs) -> torch.Tensor:
        """Pseudo-inverse, implemented by delegating to :meth:`A_adjoint`.

        :param torch.Tensor y: input kspace of shape `(B,2,T,H,W)`
        :param torch.Tensor mask: optionally set mask on-the-fly.
        :return: (:class:`torch.Tensor`) result of :meth:`A_adjoint`.
        """
        return self.A_adjoint(y, mask=mask, **kwargs)

    def check_mask(self, mask: torch.Tensor = None, **kwargs) -> Tensor | None:
        r"""
        Verifies the MRI mask shape to be B,C,T,H,W and returns the checked mask.

        Leading singleton dimensions are prepended until the mask has 5 dims, then the mask is handed to the parent
        ``check_mask`` together with ``self.three_d``. Extra ``kwargs`` are accepted but not used.

        :param torch.nn.parameter.Parameter, torch.Tensor mask: MRI subsampling mask.
        :return: checked mask as returned by the parent ``check_mask``.
        """
        while mask is not None and len(mask.shape) < 5:  # to B,C,T,H,W
            mask = mask.unsqueeze(0)

        return super().check_mask(mask=mask, three_d=self.three_d)

    def noise(self, x, **kwargs):
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)`

        The output of the noise model is multiplied by the mask.

        :param torch.Tensor x: clean measurements
        :return torch.Tensor: noisy measurements
        """
        return self.noise_model(x, **kwargs) * self.mask

    def to_static(
        self, mask: torch.Tensor | None = None, device: str | torch.device = "cpu"
    ) -> MRI:
        """Convert dynamic MRI to static MRI by removing time dimension.

        :param torch.Tensor mask: new static MRI mask. If None, existing mask is flattened (summed) along the time
            dimension and clipped to the range [0, 1].
        :param str, torch.device device: device passed to the new :class:`MRI` physics.
        :return MRI: static MRI physics
        """
        return MRI(
            mask=torch.clip(self.mask.sum(2), 0.0, 1.0) if mask is None else mask,
            img_size=self.img_size,
            device=device,
        )
class SequentialMRI(DynamicMRI):
    r"""
    Single-coil accelerated magnetic resonance imaging using sequential sampling.

    Let :math:`M` be a subsampling mask with given acceleration.
    :math:`M_t` is a time-varying mask with the sequential sampling pattern e.g. non-overlapping lines or spokes,
    such that :math:`S=\bigcup_t S_t`.
    The sequential MRI operator then simulates a time sequence of k-space samples:

    .. math::

        y_t = M_t F x

    where :math:`F` is the 2D discrete Fourier Transform, the image :math:`x` is of shape (B, 2, H, W) and
    measurements :math:`y` is of shape (B, 2, T, H, W) where the first channel corresponds to the real part
    and the second channel corresponds to the imaginary part.

    This operator has a simple singular value decomposition, so it inherits the structure of
    :class:`deepinv.physics.DecomposablePhysics` and thus has a fast pseudo-inverse and prox operators.

    A fixed mask can be set at initialisation, or a new mask can be set either at forward
    (using ``physics(x, mask=mask)``) or using ``update``.

    .. note::

        We provide various random mask generators (e.g. Cartesian undersampling) that can be used directly with this
        physics. See e.g. :class:`deepinv.physics.generator.mri.RandomMaskGenerator`

    :param torch.Tensor mask: binary mask :math:`S_t,t=1\ldots T`, where 1s represent sampling locations, and 0s
        otherwise. The mask size can either be (H,W), (T,H,W), (C,T,H,W) or (B,C,T,H,W) where H, W are the image
        height and width, T is time-steps, C is channels (typically 2) and B is batch size.
    :param tuple img_size: if mask not specified, flat mask of ones is created using ``img_size``,
        where ``img_size`` can be of any shape specified above. If mask provided, ``img_size`` is ignored.
    :param torch.device device: cpu or gpu.

    |sep|

    :Examples:

        Single-coil accelerated sequential MRI operator:

        >>> from deepinv.physics import SequentialMRI
        >>> x = torch.randn(1, 2, 2, 2) # Define random image of shape (B,C,H,W)
        >>> mask = torch.zeros(1, 2, 3, 2, 2) # Empty demo time-varying mask with 3 frames
        >>> physics = SequentialMRI(mask=mask) # Physics with given mask
        >>> physics.update(mask=mask) # Alternatively set mask on-the-fly
        >>> physics(x).shape # MRI sequential samples
        torch.Size([1, 2, 3, 2, 2])

    """

    def A(self, x: Tensor, mask: Tensor = None, **kwargs) -> torch.Tensor:
        """Forward operator: repeat the static image against the mask, then apply the dynamic operator.

        :param torch.Tensor x: input image of shape (B,2,H,W)
        :param torch.Tensor mask: optionally set mask on-the-fly; the stored mask is used for ``repeat`` otherwise.
        :return: (:class:`torch.Tensor`) output of the parent dynamic forward operator.
        """
        # ``repeat`` needs a concrete mask, whereas the parent operator still
        # receives the caller's ``mask`` argument untouched (possibly None).
        reference_mask = self.mask if mask is None else mask
        repeated = self.repeat(x, reference_mask)
        return super().A(repeated, mask, **kwargs)

    def A_adjoint(
        self, y: Tensor, mask: Tensor = None, keep_time_dim=False, **kwargs
    ) -> torch.Tensor:
        r"""
        Computes the adjoint of the forward operator :math:`\tilde{x} = A^{\top}y`.

        :param torch.Tensor y: input tensor
        :param torch.nn.parameter.Parameter, float mask: input mask
        :param bool keep_time_dim: if ``True``, adjoint is calculated frame-by-frame. Used for visualisation.
            If ``False``, flatten the time dimension before calculating.
        :return: (:class:`torch.Tensor`) output tensor
        """
        if keep_time_dim:
            # Frame-by-frame adjoint through the dynamic operator.
            return super().A_adjoint(y, mask, **kwargs)

        if mask is None:
            mask = self.mask

        # Otherwise build a static operator on the device of ``y`` and feed it
        # the averaged measurements and averaged mask.
        static_physics = self.to_static(device=y.device)
        return static_physics.A_adjoint(
            self.average(y, mask), mask=self.average(mask), **kwargs
        )