.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/self-supervised-learning/demo_equivariant_imaging.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_self-supervised-learning_demo_equivariant_imaging.py:


Self-supervised learning with Equivariant Imaging for MRI.
====================================================================================================

This example shows you how to train a reconstruction network for an MRI inverse problem
in a fully self-supervised way, i.e., using measurement data only.

The equivariant imaging loss is presented in `"Equivariant Imaging: Learning Beyond the Range Space" `_.

.. GENERATED FROM PYTHON SOURCE LINES 11-22

.. code-block:: Python


    from pathlib import Path

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms

    import deepinv as dinv
    from deepinv.datasets import SimpleFastMRISliceDataset
    from deepinv.utils.demo import get_data_home, load_degradation, demo_mri_model
    from deepinv.models.utils import get_weights_url


.. GENERATED FROM PYTHON SOURCE LINES 23-26

Set up paths for data loading and results.
---------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 26-36

.. code-block:: Python


    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set the global random seed from pytorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"


.. GENERATED FROM PYTHON SOURCE LINES 37-56

Load base image datasets and degradation operators.
----------------------------------------------------------------------------------

In this example, we use a mini demo subset of the single-coil `FastMRI dataset `_
as the base image dataset, consisting of 2 knee images of size 320x320.

.. seealso::

  Datasets :class:`deepinv.datasets.FastMRISliceDataset` :class:`deepinv.datasets.SimpleFastMRISliceDataset`
      We provide convenient datasets to easily load both raw and reconstructed FastMRI images.

      You can download more data on the `FastMRI site `_.

.. important::

    By using this dataset, you confirm that you have agreed to and signed the `FastMRI data use agreement `_.

.. note::

    We reduce the image size to 128x128 for faster training in the demo.

.. GENERATED FROM PYTHON SOURCE LINES 56-69

.. code-block:: Python


    operation = "MRI"
    img_size = 128

    transform = transforms.Compose([transforms.Resize(img_size)])

    train_dataset = SimpleFastMRISliceDataset(
        get_data_home(), transform=transform, train_percent=0.5, train=True, download=True
    )
    test_dataset = SimpleFastMRISliceDataset(
        get_data_home(), transform=transform, train_percent=0.5, train=False
    )


The reconstruction network is created by :func:`deepinv.utils.demo.demo_mri_model`; see its documentation for details.

.. GENERATED FROM PYTHON SOURCE LINES 111-115

.. code-block:: Python


    model = demo_mri_model(device=device)


.. GENERATED FROM PYTHON SOURCE LINES 116-131

Set up the training parameters
--------------------------------------------

We choose a self-supervised training scheme with two losses: the measurement consistency loss (MC)
and the equivariant imaging loss (EI).
The EI loss requires a group of transformations to be defined.
The forward model `should not be equivariant to these transformations `_.
Here we use the group of 4 rotations of 90 degrees, since accelerated MRI acquisition is
not equivariant to rotations (although it is equivariant to translations).

See the :ref:`docs ` for the full list of available transforms.

.. note::

    We use a pretrained model to reduce training time. You can get the same results by training from scratch
    for 150 epochs using a larger knee dataset of ~1000 images.

.. GENERATED FROM PYTHON SOURCE LINES 131-156

.. code-block:: Python


    epochs = 1  # choose training epochs
    learning_rate = 5e-4
    batch_size = 16 if torch.cuda.is_available() else 1

    # choose self-supervised training losses
    # generates 4 random rotations per image in the batch
    losses = [dinv.loss.MCLoss(), dinv.loss.EILoss(dinv.transform.Rotate(n_trans=4))]

    # choose optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8) + 1)

    # start with a pretrained model to reduce training time
    file_name = "new_demo_ei_ckp_150_v3.pth"
    url = get_weights_url(model_name="demo", file_name=file_name)
    ckpt = torch.hub.load_state_dict_from_url(
        url,
        map_location=lambda storage, loc: storage,
        file_name=file_name,
    )
    # load a checkpoint to reduce training time
    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optimizer"])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://huggingface.co/deepinv/demo/resolve/main/new_demo_ei_ckp_150_v3.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/new_demo_ei_ckp_150_v3.pth


.. container:: sphx-glr-download sphx-glr-download-python

  :download:`Download Python source code: demo_equivariant_imaging.py `

.. container:: sphx-glr-download sphx-glr-download-zip

  :download:`Download zipped: demo_equivariant_imaging.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
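
The MC and EI losses chosen in the training section can be illustrated with a minimal plain-PyTorch sketch. This is not deepinv's implementation of :class:`deepinv.loss.MCLoss` or :class:`deepinv.loss.EILoss`. The toy single-coil operator (an orthonormal 2D FFT followed by a column undersampling mask), the mask pattern and the small residual CNN are hypothetical stand-ins for the demo's physics and for :func:`deepinv.utils.demo.demo_mri_model`. This applies to ``make_mask``, ``forward_op``, ``adjoint`` and ``ToyRecon``, which are all illustrative names.

Write the reconstructor as :math:`f`, the forward operator as :math:`A` and a random rotation as :math:`T_g`, and let :math:`x_1 = f(y)`. The sketch then computes two losses:

* the MC loss :math:`\|A f(y) - y\|^2`;
* the EI loss :math:`\|f(A T_g x_1) - T_g x_1\|^2`.

For simplicity it draws one rotation by a multiple of 90 degrees per call, whereas ``dinv.transform.Rotate(n_trans=4)`` above generates 4 random rotations per image.

.. code-block:: Python

    import torch
    import torch.nn as nn

    torch.manual_seed(0)


    def make_mask(n: int, accel: int = 4, n_center: int = 8) -> torch.Tensor:
        """Hypothetical Cartesian mask: every `accel`-th k-space column plus a central band."""
        mask = torch.zeros(n, n)
        mask[:, ::accel] = 1.0
        c = n // 2
        mask[:, c - n_center // 2 : c + n_center // 2] = 1.0
        return mask


    def forward_op(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Toy single-coil MRI operator A: orthonormal 2D FFT, then undersampling."""
        return mask * torch.fft.fft2(x, norm="ortho")


    def adjoint(y: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Zero-filled reconstruction: adjoint of A, keeping the real part."""
        return torch.fft.ifft2(mask * y, norm="ortho").real


    class ToyRecon(nn.Module):
        """Hypothetical reconstructor f(y): zero-filled image plus a learned residual."""

        def __init__(self, mask: torch.Tensor):
            super().__init__()
            self.register_buffer("mask", mask)
            self.net = nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 3, padding=1)
            )

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            x_zf = adjoint(y, self.mask)
            return x_zf + self.net(x_zf)


    def mc_loss(model: ToyRecon, y: torch.Tensor) -> torch.Tensor:
        """Measurement consistency: mean squared modulus of A f(y) - y."""
        r = forward_op(model(y), model.mask) - y
        return (r.real ** 2 + r.imag ** 2).mean()


    def ei_loss(model: ToyRecon, y: torch.Tensor) -> torch.Tensor:
        """Equivariant imaging: mean squared error between f(A T_g x1) and T_g x1, x1 = f(y)."""
        x1 = model(y)
        k = int(torch.randint(1, 4, (1,)))  # random non-identity rotation by k * 90 degrees
        x2 = torch.rot90(x1, k, dims=(-2, -1))
        x3 = model(forward_op(x2, model.mask))
        return ((x3 - x2) ** 2).mean()


    # Simulate measurements from toy images; the training loop only ever uses y.
    n = 32
    mask = make_mask(n)
    y = forward_op(torch.rand(2, 1, n, n), mask)

    model = ToyRecon(mask)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    for step in range(5):
        loss = mc_loss(model, y) + ei_loss(model, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

With this column mask, the zero-filled projection ``adjoint(forward_op(x, mask), mask)`` commutes with circular shifts of ``x`` but not with 90-degree rotations. This matches the reason given above for using rotations rather than translations in the EI loss.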