FastMRISliceDataset

class deepinv.datasets.FastMRISliceDataset(root, load_metadata_from_cache=False, save_metadata_to_cache=False, metadata_cache_file='dataset_cache.pkl', slice_index='all', subsample_volumes=1.0, transform=None, filter_id=None, rng=None)

Bases: Dataset, MRIMixin

Dataset for fastMRI that provides access to raw MR image slices.

This dataset randomly selects 2D slices from a dataset of 3D MRI volumes. Each data sample is one slice of an MRI scan, so slices from the same scan are treated as independent samples.

To obtain the raw data, go to the bottom of the page https://fastmri.med.nyu.edu/ and download the brain/knee and train/validation/test volumes as h5 files.

The dataset is loaded as tuples (x, y). Here y are the kspace measurements of shape (2, (N,) H, W), where N is the optional coil dimension, present only for multicoil data, and x (“target”) are the magnitude root-sum-square reconstructions of shape (1, H, W).

If a transform is used or a mask exists in the file, each sample also includes a params dict containing e.g. mask and/or coil_maps.

Tip

x and y are related by deepinv.physics.MRI.A_adjoint() or deepinv.physics.MultiCoilMRI.A_adjoint(), depending on whether y is singlecoil or multicoil, with crop=True, rss=True.

See the fastMRI README for more details.
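
The relationship between y and x described in the tip can be illustrated with plain PyTorch. This is a minimal sketch, not the library's implementation: it assumes a centred, orthonormal inverse FFT, omits the cropping step, and uses a hypothetical function name, rss_from_kspace, with synthetic data in place of a real sample.

```python
import torch


def rss_from_kspace(y: torch.Tensor) -> torch.Tensor:
    """Root-sum-square magnitude image from multicoil k-space.

    y has shape (2, N, H, W): real and imaginary parts stacked on dim 0.
    Returns a tensor of shape (1, H, W). Cropping is omitted (assumption).
    """
    kspace = torch.complex(y[0], y[1])  # (N, H, W), complex
    # Centred 2D inverse FFT per coil; the "ortho" norm is an assumption.
    kspace = torch.fft.ifftshift(kspace, dim=(-2, -1))
    coil_images = torch.fft.fftshift(
        torch.fft.ifft2(kspace, norm="ortho"), dim=(-2, -1)
    )
    # Combine coils: square root of the sum of squared magnitudes.
    return coil_images.abs().pow(2).sum(dim=0, keepdim=True).sqrt()


y = torch.randn(2, 4, 8, 6)  # synthetic stand-in for a multicoil sample
x = rss_from_kspace(y)       # shape (1, 8, 6)
```

In practice the physics operators named in the tip should be preferred, since they also handle cropping to the target size.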

Raw data file structure:

self.root --- file1000005.h5
           |
           -- xxxxxxxxxxx.h5

Each file contains the k-space data, reconstructed images and some metadata related to the scan. When using this class, consider using the metadata_cache options to speed up class initialisation after the first initialisation.

Note

We also provide a simple FastMRI dataset class in deepinv.datasets.fastmri.SimpleFastMRISliceDataset. This allows you to save and load the dataset as 2D singlecoil slices much faster and all in-memory. You can generate this using the method save_simple_dataset.

Important

By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.

Parameters:
  • root (Union[str, pathlib.Path]) – Path to the dataset.

  • load_metadata_from_cache (bool) – Whether to load dataset metadata from cache.

  • save_metadata_to_cache (bool) – Whether to cache dataset metadata.

  • metadata_cache_file (Union[str, pathlib.Path]) – A file used to cache dataset information for faster load times.

  • subsample_volumes (float) – (optional) proportion of volumes to be randomly subsampled (float between 0 and 1).

  • slice_index (str, int, tuple) – which slices to keep per volume: if "all", keep all slices; if int, keep only that slice; if tuple[int], keep those slices; if "middle", keep the middle slice; if "middle+i", keep the \(2i+1\) slices centred on the middle slice; if "random", select a random slice. Defaults to "all".

  • transform (Callable) – optional transform function taking in (multicoil) kspace of shape (2, (N,) H, W) and targets of shape (1, H, W). See also deepinv.datasets.MRISliceTransform, a transform for working with raw data: simulate masks and estimate coil maps.

  • filter_id (Callable) – optional function that takes a SliceSampleID named tuple and returns whether this id should be included.

  • rng (torch.Generator, None) – optional torch random generator for shuffling slice indices.
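
The slice_index options listed above can be read as a mapping from a volume's slice count to the indices that are kept. The sketch below only illustrates that reading and is not the library's code: the helper select_slices is hypothetical, the middle slice is assumed to be n_slices // 2, and Python's random module stands in for the torch generator.

```python
import random


def select_slices(n_slices, slice_index="all", rng=None):
    """Return the slice indices kept for one volume with n_slices slices."""
    middle = n_slices // 2  # assumed definition of the middle slice
    if slice_index == "all":
        return list(range(n_slices))
    if isinstance(slice_index, int):
        return [slice_index]
    if isinstance(slice_index, (tuple, list)):
        return list(slice_index)
    if slice_index == "middle":
        return [middle]
    if slice_index.startswith("middle+"):
        i = int(slice_index.split("+")[1])
        # 2i + 1 slices centred on the middle slice
        return list(range(middle - i, middle + i + 1))
    if slice_index == "random":
        rng = rng or random.Random()
        return [rng.randrange(n_slices)]
    raise ValueError(f"Unsupported slice_index: {slice_index!r}")


kept_middle = select_slices(16, "middle+2")
kept_tuple = select_slices(16, (1, 3))
```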


Examples:

Instantiate dataset with sample data (from a demo multicoil brain volume):

>>> from deepinv.datasets import FastMRISliceDataset, download_archive
>>> from deepinv.utils import get_image_url, get_data_home
>>> url = get_image_url("demo_fastmri_brain_multicoil.h5")
>>> root = get_data_home() / "fastmri" / "brain"
>>> download_archive(url, root / "demo.h5")
>>> dataset = FastMRISliceDataset(root=root, slice_index="all")
>>> len(dataset)
16
>>> target, kspace = dataset[0]
>>> target.shape # (1, W, W), varies per sample
torch.Size([1, 213, 213])
>>> kspace.shape # (2, N, H, W), varies per sample
torch.Size([2, 4, 512, 213])

Load one slice per volume:

>>> dataset = FastMRISliceDataset(root=root, slice_index=0)

Use MRI transform to mask, estimate sensitivity maps, normalise and/or crop:

>>> from deepinv.datasets import MRISliceTransform
>>> from deepinv.physics.generator import GaussianMaskGenerator
>>> mask_generator = GaussianMaskGenerator((512, 213))
>>> dataset = FastMRISliceDataset(root, transform=MRISliceTransform(mask_generator=mask_generator, estimate_coil_maps=True))
>>> target, kspace, params = dataset[0]
>>> params["mask"].shape
torch.Size([1, 512, 213])
>>> params["coil_maps"].shape
torch.Size([4, 512, 213])

Filter by volume ID:

>>> dataset = FastMRISliceDataset(root, filter_id=lambda s: "brain" in str(s.fname))
>>> len(dataset)
16

Convert to a simple normalised padded in-memory slice dataset from the middle slices only:

>>> simple_set = FastMRISliceDataset(root=root, slice_index="middle").save_simple_dataset(root.parent / "simple_set.pt")
>>> len(simple_set)
1

Instantiate dataset with metadata cache (speeds up subsequent instantiation):

>>> dataset = FastMRISliceDataset(root=root, load_metadata_from_cache=True, save_metadata_to_cache=True, metadata_cache_file=root.parent / "cache.pkl")
Saving dataset cache to ...
>>> import shutil; shutil.rmtree(root.parent)

class SliceSampleID(fname, slice_ind, metadata)

Bases: NamedTuple

Data structure containing ID and metadata of specific slices within MRI data files.

fname: Path

Alias for field number 0

metadata: Dict[str, Any]

Alias for field number 2

slice_ind: int

Alias for field number 1
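
Because filter_id receives one of these named tuples, a filter can combine the file name with the slice index. The following sketch uses a local stand-in class that mirrors the three documented fields, so it runs without any data files; the file names and the filter itself are made up for illustration.

```python
from pathlib import Path
from typing import Any, Dict, NamedTuple


class SliceSampleID(NamedTuple):
    """Local stand-in mirroring the documented fields (not the real class)."""

    fname: Path
    slice_ind: int
    metadata: Dict[str, Any]


def keep_early_brain_slices(sample_id: SliceSampleID) -> bool:
    """Keep only the first two slices of files whose path mentions 'brain'."""
    return "brain" in str(sample_id.fname) and sample_id.slice_ind < 2


# Hypothetical sample ids, as the dataset might enumerate them.
samples = [
    SliceSampleID(Path("brain/file_a.h5"), 0, {}),
    SliceSampleID(Path("brain/file_a.h5"), 5, {}),
    SliceSampleID(Path("knee/file_b.h5"), 0, {}),
]
kept = [s for s in samples if keep_early_brain_slices(s)]
```

A function with this signature could then be passed as filter_id when constructing the dataset.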

metadata_cache_manager(root, samples)

Read/write metadata cache file for populating list of sample ids.

Parameters:
  • root (Union[str, pathlib.Path]) – root dir to save to metadata cache

  • samples (Any) – iterable (list, dict etc.) for populating with samples to read/write to metadata cache

Yield:

samples, either populated from the metadata cache or blank, in which case the caller is expected to fill them so that they can be written to the cache.
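
The read-or-populate pattern described here can be sketched with a generic context manager. This is an illustration under stated assumptions, not the method's actual code: the helper cached_samples and its signature are hypothetical, and the cache is assumed to be a pickled dict mapping a key to a list of samples.

```python
import pickle
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def cached_samples(cache_file, key, samples):
    """Yield `samples`, pre-filled from the cache when an entry exists.

    Hypothetical helper: new entries are written back to the file on exit.
    """
    cache_file = Path(cache_file)
    cache = {}
    if cache_file.exists():
        with open(cache_file, "rb") as f:
            cache = pickle.load(f)
    cache_hit = key in cache
    if cache_hit:
        samples.extend(cache[key])
    yield samples
    if not cache_hit:
        cache[key] = list(samples)
        with open(cache_file, "wb") as f:
            pickle.dump(cache, f)


with tempfile.TemporaryDirectory() as tmp:
    cache_file = Path(tmp) / "dataset_cache.pkl"

    # First pass: nothing cached yet, so the caller fills the list itself.
    with cached_samples(cache_file, "root_a", []) as first:
        if not first:
            first.extend([("file1.h5", 0), ("file1.h5", 1)])

    # Second pass: the samples come straight from the cache file.
    with cached_samples(cache_file, "root_a", []) as second:
        pass
```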

save_simple_dataset(dataset_path, pad_to_size=(320, 320), to_complex=False)

Convert dataset to a 2D singlecoil dataset and save as pickle file.

This allows the dataset to be loaded in memory with deepinv.datasets.fastmri.SimpleFastMRISliceDataset.

Example:

Load local brain dataset and convert to simple dataset

from deepinv.datasets import FastMRISliceDataset
root = "/path/to/dataset/fastMRI/brain/multicoil_train"
dataset = FastMRISliceDataset(root=root, slice_index="middle")
subset = dataset.save_simple_dataset(root + "/fastmri_brain_singlecoil.pt")

Parameters:
  • dataset_path (str) – desired path of dataset to be saved with file extension e.g. fastmri_knee_singlecoil.pt.

  • pad_to_size (tuple[int, int], None) – if not None, normalise images to 0-1 then pad to the provided shape. Must be set if images are of varying size, so that they can be stacked into a single tensor.

Returns:

loaded SimpleFastMRISliceDataset

Return type:

SimpleFastMRISliceDataset
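
To see why pad_to_size matters for stacking, consider the sketch below. It is not the method's implementation: the helper normalise_and_pad is hypothetical, and min-max normalisation with centred zero-padding is an assumption. It also assumes no image is larger than pad_to_size and that no image is constant.

```python
import torch
import torch.nn.functional as F


def normalise_and_pad(x, pad_to_size=(320, 320)):
    """Scale x of shape (1, H, W) to [0, 1], then zero-pad to pad_to_size."""
    x = (x - x.min()) / (x.max() - x.min())  # assumes x is not constant
    pad_h = pad_to_size[0] - x.shape[-2]
    pad_w = pad_to_size[1] - x.shape[-1]
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(
        x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    )


# Slices of varying size cannot be stacked directly; after padding they can.
slices = [torch.rand(1, 213, 213), torch.rand(1, 320, 320), torch.rand(1, 256, 300)]
batch = torch.stack([normalise_and_pad(s) for s in slices])
```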

static torch_shuffle(x, generator=None)

Shuffle list reproducibly using torch generator.

Parameters:
  • x (list) – list to be shuffled.

  • generator (torch.Generator) – optional torch random generator.

Returns:

shuffled list

Return type:

list
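
One common way to shuffle a Python list reproducibly with a torch generator is to draw a permutation with torch.randperm. The sketch below shows that approach; it is an assumption about how such a helper could work, not a copy of torch_shuffle, and the name shuffle_with_generator is hypothetical.

```python
import torch


def shuffle_with_generator(x, generator=None):
    """Return a shuffled copy of list x using a torch permutation."""
    perm = torch.randperm(len(x), generator=generator).tolist()
    return [x[i] for i in perm]


items = list("abcdef")
# Two generators with the same seed give the same ordering.
a = shuffle_with_generator(items, torch.Generator().manual_seed(0))
b = shuffle_with_generator(items, torch.Generator().manual_seed(0))
```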

Examples using FastMRISliceDataset:

Tour of MRI functionality in DeepInverse

Self-supervised learning with Equivariant Imaging for MRI.