FastMRISliceDataset
- class deepinv.datasets.FastMRISliceDataset(root, load_metadata_from_cache=False, save_metadata_to_cache=False, metadata_cache_file='dataset_cache.pkl', slice_index='all', subsample_volumes=1.0, transform=None, filter_id=None, rng=None)
Dataset for fastMRI that provides access to raw MR image slices.
This dataset randomly selects 2D slices from a dataset of 3D MRI volumes. This class treats one slice of an MRI scan as one data sample, so slices of the same MRI scan are considered independently in the dataset.
To download raw data, go to the bottom of the page https://fastmri.med.nyu.edu/ and download the brain/knee and train/validation/test volumes as h5 files.
The dataset is loaded as tuples (x, y), where y are the k-space measurements of shape (2, (N,) H, W), where N is the optional coil dimension depending on whether the data is singlecoil or multicoil, and x ("target") are the magnitude root-sum-square reconstructions of shape (1, H, W).
If transform is used or mask exists in the file, then the dataset also returns a params dict containing e.g. mask and/or coil_maps.
Tip
x and y are related by deepinv.physics.MRI.A_adjoint() or deepinv.physics.MultiCoilMRI.A_adjoint(), depending on whether y is multicoil or not, with crop=True, rss=True.
See the fastMRI README for more details.
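The relationship between y and x described above can be illustrated with a minimal NumPy sketch. It is not the deepinv implementation: the function name rss_from_kspace, the centred orthonormal inverse FFT convention and the crop size are assumptions made for illustration only.

```python
import numpy as np


def rss_from_kspace(y, crop_size):
    """Hypothetical helper: k-space (2, (N,) H, W) -> magnitude RSS image (1, h, w)."""
    if y.ndim == 3:
        # Singlecoil data: add a coil dimension of size 1
        y = y[:, None]
    # Combine the real/imaginary channels into a complex array of shape (N, H, W)
    kspace = y[0] + 1j * y[1]
    # Centred, orthonormal inverse 2D FFT (assumed convention)
    images = np.fft.fftshift(
        np.fft.ifft2(np.fft.ifftshift(kspace, axes=(-2, -1)), norm="ortho"),
        axes=(-2, -1),
    )
    # Centre crop to the target size
    h, w = crop_size
    H, W = images.shape[-2:]
    top, left = (H - h) // 2, (W - w) // 2
    images = images[..., top:top + h, left:left + w]
    # Root-sum-square over the coil dimension, keeping a channel dimension of 1
    return np.sqrt((np.abs(images) ** 2).sum(axis=0, keepdims=True))


# Random multicoil k-space with 4 coils, cropped to a square target
y_demo = np.random.default_rng(0).standard_normal((2, 4, 16, 8))
x_demo = rss_from_kspace(y_demo, crop_size=(8, 8))
print(x_demo.shape)
```

The output has a single real, non-negative channel, matching the documented target shape (1, H, W) after cropping.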
Raw data file structure:
self.root --- file1000005.h5
           |
           -- xxxxxxxxxxx.h5
Each file contains the k-space data, reconstructed images and some metadata related to the scan. When using this class, consider using the metadata_cache options to speed up class initialisation after the first initialisation.
Note
We also provide a simple FastMRI dataset class in deepinv.datasets.fastmri.SimpleFastMRISliceDataset. This allows you to save and load the dataset as 2D singlecoil slices much faster and all in-memory. You can generate this using the method save_simple_dataset.
Important
By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.
- Parameters:
root (Union[str, pathlib.Path]) – Path to the dataset.
load_metadata_from_cache (bool) – Whether to load dataset metadata from cache.
save_metadata_to_cache (bool) – Whether to cache dataset metadata.
metadata_cache_file (Union[str, pathlib.Path]) – A file used to cache dataset information for faster load times.
subsample_volumes (float) – (optional) proportion of volumes to be randomly subsampled (float between 0 and 1).
slice_index (str, int, tuple) – if "all", keep all slices per volume; if int or tuple[int], keep only those indexed slices per volume; if "middle", keep the middle slice; if "middle+i", keep the \(2i+1\) slices about the middle slice; if "random", select a random slice. Defaults to "all".
transform (Callable) – optional transform function taking in (multicoil) kspace of shape (2, (N,) H, W) and targets of shape (1, H, W).
See also
deepinv.datasets.MRISliceTransform
Transform for working with raw data: simulate masks and estimate coil maps.
- Parameters:
filter_id (Callable) – optional function that takes a SliceSampleID named tuple and returns whether this id should be included.
rng (torch.Generator, None) – optional torch random generator for shuffling slice indices.
- Examples:
Instantiate dataset with sample data (from a demo multicoil brain volume):
>>> from deepinv.datasets import FastMRISliceDataset, download_archive
>>> from deepinv.utils import get_image_url, get_data_home
>>> url = get_image_url("demo_fastmri_brain_multicoil.h5")
>>> root = get_data_home() / "fastmri" / "brain"
>>> download_archive(url, root / "demo.h5")
>>> dataset = FastMRISliceDataset(root=root, slice_index="all")
>>> len(dataset)
16
>>> target, kspace = dataset[0]
>>> target.shape # (1, W, W), varies per sample
torch.Size([1, 213, 213])
>>> kspace.shape # (2, N, H, W), varies per sample
torch.Size([2, 4, 512, 213])
Load one slice per volume:
>>> dataset = FastMRISliceDataset(root=root, slice_index=0)
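The slice_index options described in the parameter list can be sketched in plain Python as follows. This is only an illustration of the documented semantics, not the library's code: the helper name select_slice_indices, the definition of the middle slice as n_slices // 2, and the clipping at the volume boundary are assumptions.

```python
import random


def select_slice_indices(n_slices, slice_index="all", rng=None):
    """Hypothetical helper: map a slice_index option to slice indices of one volume."""
    all_indices = list(range(n_slices))
    middle = n_slices // 2  # assumed definition of the middle slice
    if slice_index == "all":
        return all_indices
    if slice_index == "middle":
        return [middle]
    if slice_index == "random":
        rng = rng or random.Random()
        return [rng.choice(all_indices)]
    if isinstance(slice_index, str) and slice_index.startswith("middle+"):
        # "middle+i" keeps 2i+1 slices centred on the middle slice
        i = int(slice_index[len("middle+"):])
        return all_indices[max(middle - i, 0):middle + i + 1]
    if isinstance(slice_index, int):
        return [all_indices[slice_index]]
    if isinstance(slice_index, (tuple, list)):
        return [all_indices[i] for i in slice_index]
    raise ValueError(f"Unsupported slice_index: {slice_index!r}")


print(select_slice_indices(16, "middle+2"))  # [6, 7, 8, 9, 10]
print(select_slice_indices(16, (1, 3)))      # [1, 3]
```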
Use MRI transform to mask, estimate sensitivity maps, normalise and/or crop:
>>> from deepinv.datasets import MRISliceTransform
>>> from deepinv.physics.generator import GaussianMaskGenerator
>>> mask_generator = GaussianMaskGenerator((512, 213))
>>> dataset = FastMRISliceDataset(root, transform=MRISliceTransform(mask_generator=mask_generator, estimate_coil_maps=True))
>>> target, kspace, params = dataset[0]
>>> params["mask"].shape
torch.Size([1, 512, 213])
>>> params["coil_maps"].shape
torch.Size([4, 512, 213])
Filter by volume ID:
>>> dataset = FastMRISliceDataset(root, filter_id=lambda s: "brain" in str(s.fname))
>>> len(dataset)
16
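A filter_id callable can also combine several fields of the sample id. The sketch below uses a local stand-in for SliceSampleID with the documented fields (fname, slice_ind, metadata); the field types, the file names and the filter itself are hypothetical and only illustrate the idea.

```python
from pathlib import Path
from typing import NamedTuple


class SliceSampleID(NamedTuple):
    """Local stand-in mirroring the documented fields (types are assumptions)."""
    fname: Path
    slice_ind: int
    metadata: dict


def keep_early_brain_slices(s):
    # Keep only the first two slices of volumes whose path mentions "brain"
    return "brain" in str(s.fname) and s.slice_ind < 2


# Hypothetical sample ids, for illustration only
samples = [
    SliceSampleID(Path("fastmri/brain/demo.h5"), 0, {}),
    SliceSampleID(Path("fastmri/brain/demo.h5"), 1, {}),
    SliceSampleID(Path("fastmri/brain/demo.h5"), 2, {}),
    SliceSampleID(Path("fastmri/knee/demo.h5"), 0, {}),
]
kept = [s for s in samples if keep_early_brain_slices(s)]
print(len(kept))  # 2
```

A function like keep_early_brain_slices could then be passed as filter_id in place of the lambda shown above.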
Convert to a simple normalised padded in-memory slice dataset from the middle slices only:
>>> simple_set = FastMRISliceDataset(root=root, slice_index="middle").save_simple_dataset(root.parent / "simple_set.pt")
>>> len(simple_set)
1
Instantiate dataset with metadata cache (speeds up subsequent instantiation):
>>> dataset = FastMRISliceDataset(root=root, load_metadata_from_cache=True, save_metadata_to_cache=True, metadata_cache_file=root.parent / "cache.pkl")
Saving dataset cache to ...
>>> import shutil; shutil.rmtree(root.parent)
- class SliceSampleID(fname, slice_ind, metadata)
Bases: NamedTuple
Data structure containing ID and metadata of specific slices within MRI data files.
- metadata_cache_manager(root, samples)
Read/write metadata cache file for populating list of sample ids.
- Parameters:
root (Union[str, pathlib.Path]) – root dir to save to metadata cache
samples (Any) – iterable (list, dict etc.) for populating with samples to read/write to metadata cache
- Yield:
samples, either populated from metadata cache, or blank, to be yielded to be written to.
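The read/write pattern described here can be sketched as a generic context manager. This is an assumption-laden illustration, not the library's implementation: the name cached_samples, its load/save flags and the scan_volumes stand-in are hypothetical.

```python
import pickle
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def cached_samples(cache_file, load=True, save=True):
    """Hypothetical helper: yield samples from a pickle cache, or a blank list to fill."""
    cache_file = Path(cache_file)
    if load and cache_file.exists():
        # Cache hit: yield the stored samples, nothing to write afterwards
        with cache_file.open("rb") as f:
            samples = pickle.load(f)
        yield samples
    else:
        # Cache miss: yield a blank list for the caller to populate, then save it
        samples = []
        yield samples
        if save:
            with cache_file.open("wb") as f:
                pickle.dump(samples, f)


def scan_volumes():
    # Stand-in for the expensive step of reading metadata from every file
    return [("file1000005.h5", 0), ("file1000005.h5", 1)]


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "dataset_cache.pkl"
    with cached_samples(cache) as first:
        if not first:
            first.extend(scan_volumes())
    with cached_samples(cache) as second:
        pass  # populated from the cache, no scan needed

print(first == second)  # True
```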
- save_simple_dataset(dataset_path, pad_to_size=(320, 320), to_complex=False)
Convert dataset to a 2D singlecoil dataset and save as pickle file.
This allows the dataset to be loaded in memory with deepinv.datasets.fastmri.SimpleFastMRISliceDataset.
- Example:
Load local brain dataset and convert to simple dataset
from deepinv.datasets import FastMRISliceDataset
root = "/path/to/dataset/fastMRI/brain/multicoil_train"
dataset = FastMRISliceDataset(root=root, slice_index="middle")
subset = dataset.save_simple_dataset(root + "/fastmri_brain_singlecoil.pt")
- Parameters:
- Returns:
loaded SimpleFastMRISliceDataset
- Return type:
deepinv.datasets.fastmri.SimpleFastMRISliceDataset
- static torch_shuffle(x, generator=None)
Shuffle list reproducibly using torch generator.
- Parameters:
x (list) – list to be shuffled
generator (torch.Generator) – torch Generator.
- Return list:
shuffled list
- Return type:
list
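Reproducible shuffling with a torch generator can be sketched as below. This is one possible approach using torch.randperm, not necessarily how torch_shuffle is implemented; the function name shuffle_with_generator is hypothetical.

```python
import torch


def shuffle_with_generator(x, generator=None):
    """Hypothetical helper: return a shuffled copy of a list using a torch generator."""
    # Draw a random permutation of indices from the given generator
    perm = torch.randperm(len(x), generator=generator).tolist()
    return [x[i] for i in perm]


# Two generators with the same seed give the same shuffled order
g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)
a = shuffle_with_generator(list("abcdef"), g1)
b = shuffle_with_generator(list("abcdef"), g2)
print(a == b)  # True
```

Seeding the generator passed as rng in the same way makes any random slice or volume selection repeatable across runs.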
Examples using FastMRISliceDataset:
Self-supervised learning with Equivariant Imaging for MRI.