.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/self-supervised-learning/demo_scan_specific.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial`.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_self-supervised-learning_demo_scan_specific.py:

Scan-specific zero-shot SSDU for MRI
====================================

We demonstrate scan-specific self-supervised learning, that is, learning to
reconstruct MRI scans from a single accelerated sample without ground truth.

Here, we fine-tune a pretrained model (:class:`deepinv.models.RAM`)
:footcite:p:`terris2025reconstruct` with the weighted SSDU loss
:footcite:p:`millard2023theoretical,yaman2020self`. Note, however, that any of the
:ref:`self-supervised losses` can be used for this, with varying performance
:footcite:p:`wang2025benchmarking`; see, for example, the example using
Equivariant Imaging :footcite:p:`chen2021equivariant`.

Note that, if more data is available, better results can be obtained by
fine-tuning on more samples.

.. GENERATED FROM PYTHON SOURCE LINES 16-24

.. code-block:: Python

    import torch
    import deepinv as dinv

    device = dinv.utils.get_device()
    rng = torch.Generator(device=device).manual_seed(0)
    rng_cpu = torch.Generator(device="cpu").manual_seed(0)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 4064.25 MiB free memory

.. GENERATED FROM PYTHON SOURCE LINES 25-32

Data
----
First, download a demo single brain MRI volume (FLAIR sequence, SIEMENS Trio Tim
3T scanner) from the FastMRI brain dataset :footcite:p:`knoll2020advancing`, via
HuggingFace.

.. important::

   By using this dataset, you confirm that you have agreed to and signed the
   FastMRI data use agreement.
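As background for the weighted SSDU loss used in this example: SSDU-style training
splits the sampled k-space locations into two disjoint subsets, reconstructs from
one subset, and evaluates the loss only on the held-out subset. The following is a
minimal NumPy sketch of that splitting step; the function name, split ratio, and
Bernoulli splitting are illustrative assumptions, not DeepInverse's implementation.

.. code-block:: Python

    import numpy as np

    def split_mask(mask, split_ratio=0.6, seed=0):
        """Split a binary sampling mask into two disjoint masks.

        ``mask_train`` keeps roughly ``split_ratio`` of the sampled locations
        (fed to the network); ``mask_loss`` keeps the rest (the loss is
        evaluated only there).
        """
        rng = np.random.default_rng(seed)
        sampled = np.flatnonzero(mask)                 # indices of sampled k-space points
        keep = rng.random(sampled.size) < split_ratio  # Bernoulli split of sampled points
        mask_train = np.zeros_like(mask)
        mask_loss = np.zeros_like(mask)
        mask_train.flat[sampled[keep]] = 1
        mask_loss.flat[sampled[~keep]] = 1
        return mask_train, mask_loss

    # Toy 8x8 binary sampling mask
    mask = (np.random.default_rng(1).random((8, 8)) < 0.5).astype(int)
    m_train, m_loss = split_mask(mask)
    assert np.array_equal(m_train + m_loss, mask)  # disjoint and covering

The weighted variant additionally reweights the held-out loss according to the
splitting probabilities; the splitting of the sampled set itself is the same idea.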
.. GENERATED FROM PYTHON SOURCE LINES 32-40

.. code-block:: Python

    DATA_DIR = dinv.utils.get_data_home() / "fastMRI" / "multicoil_train"
    SLICE_DIR = DATA_DIR / "slices"

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    SLICE_DIR.mkdir(exist_ok=True)

    dinv.utils.download_example("demo_fastmri_brain_multicoil.h5", DATA_DIR)

.. GENERATED FROM PYTHON SOURCE LINES 41-49

We use the FastMRI slice dataset provided in DeepInverse to load the volume and
return all 16 slices. The data is returned in the format ``x, y, params``, where
``params`` is a dictionary containing the acceleration mask (a simulated Gaussian
mask with acceleration 6) and the estimated coil sensitivity maps.

.. note::

    Loading takes a few seconds per slice, as the coil sensitivity maps must be
    estimated on the fly.

.. GENERATED FROM PYTHON SOURCE LINES 49-65

.. code-block:: Python

    dataset = dinv.datasets.FastMRISliceDataset(
        DATA_DIR,
        slice_index="all",
        transform=dinv.datasets.MRISliceTransform(
            mask_generator=dinv.physics.generator.GaussianMaskGenerator(
                img_size=(256, 256),  # this is overridden internally by true image size
                acceleration=6,
                center_fraction=0.08,
                device="cpu",
                rng=rng_cpu,
            ),
            estimate_coil_maps=True,
        ),
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/1 [00:00

…0 to wait for more validation steps.

.. GENERATED FROM PYTHON SOURCE LINES 178-182

Evaluation
----------
Now that the model is trained, we evaluate it on 3 samples, plotting and saving
the reconstructions and evaluation metrics.

.. GENERATED FROM PYTHON SOURCE LINES 182-222
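The reconstruction quality below is reported as PSNR in dB. As a reminder of what
that metric computes, here is a minimal NumPy sketch; the ``max_pixel``
normalization convention is an assumption for illustration (DeepInverse supplies
its own PSNR metric, used as ``metric`` in the evaluation code).

.. code-block:: Python

    import numpy as np

    def psnr(x_hat, x, max_pixel=1.0):
        """Peak signal-to-noise ratio (dB) between reconstruction and ground truth."""
        mse = np.mean((x_hat - x) ** 2)
        if mse == 0:
            return float("inf")  # identical images
        return float(10 * np.log10(max_pixel**2 / mse))

    x = np.zeros((4, 4))
    x_hat = np.full((4, 4), 0.1)  # constant error of 0.1 -> MSE = 0.01
    print(f"{psnr(x_hat, x):.2f} dB")  # 20.00 dB

Higher is better; a few dB of improvement over the adjoint or SENSE baselines is
the kind of gain reported in the figures below.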
.. code-block:: Python

    from torch.utils.data._utils.collate import default_collate

    for i in [len(dataset) // 2 - 1, len(dataset) // 2, len(dataset) // 2 + 1]:
        # Load slice
        x, y, params = default_collate([dataset[i]])
        x, y, params = (
            x,
            y.to(device),
            {
                k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                for (k, v) in params.items()
            },
        )

        physics.update(**params)

        # Compute baseline reconstructions
        x_adj = physics.A_adjoint(y).detach().cpu()
        x_dag = physics.A_dagger(y).detach().cpu()

        # Evaluate model
        with torch.no_grad():
            x_hat = model(y, physics).detach().cpu()

        dinv.utils.plot(
            {
                "GT": x,
                "Adjoint": crop(x_adj, x),
                "SENSE": crop(x_dag, x),
                "Trained": crop(x_hat, x),
            },
            subtitles=[
                "",
                f"{metric(x_adj, x).item():.2f} dB",
                f"{metric(x_dag, x).item():.2f} dB",
                f"{metric(x_hat, x).item():.2f} dB",
            ],
        )

.. rst-class:: sphx-glr-horizontal

    * .. image-sg:: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_001.png
         :alt: GT, Adjoint, SENSE, Trained
         :srcset: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_001.png
         :class: sphx-glr-multi-img

    * .. image-sg:: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_002.png
         :alt: GT, Adjoint, SENSE, Trained
         :srcset: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_002.png
         :class: sphx-glr-multi-img

    * .. image-sg:: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_003.png
         :alt: GT, Adjoint, SENSE, Trained
         :srcset: /auto_examples/self-supervised-learning/images/sphx_glr_demo_scan_specific_003.png
         :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 223-226

:References:

.. footbibliography::

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (2 minutes 3.765 seconds)

.. _sphx_glr_download_auto_examples_self-supervised-learning_demo_scan_specific.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example
    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_scan_specific.ipynb <demo_scan_specific.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_scan_specific.py <demo_scan_specific.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_scan_specific.zip <demo_scan_specific.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_