.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/self-supervised-learning/demo_artifact2artifact.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_self-supervised-learning_demo_artifact2artifact.py:


Self-supervised MRI reconstruction with Artifact2Artifact
=========================================================

We demonstrate the self-supervised Artifact2Artifact loss for solving an
undersampled sequential MRI reconstruction problem without ground truth.

The Artifact2Artifact loss was introduced in Liu et al. `RARE: Image Reconstruction using Deep Priors Learned without Groundtruth `__.

In our example, we use it to reconstruct **static** images from k-space
measurements that form a time sequence: each time step (phase) consists of
sampled lines, such that the whole measurement is a set of non-overlapping
lines.

For a description of how Artifact2Artifact constructs the loss, see
:class:`deepinv.loss.Artifact2ArtifactLoss`.

Note that in our implementation, this is a special case of the generic
splitting loss: see :class:`deepinv.loss.SplittingLoss` for more
details. See :class:`deepinv.loss.Phase2PhaseLoss` for the related
Phase2Phase.

.. GENERATED FROM PYTHON SOURCE LINES 26-44

.. code-block:: Python


    from pathlib import Path
    import torch
    from torch.utils.data import DataLoader, Subset
    from torchvision import transforms

    import deepinv as dinv
    from deepinv.datasets import SimpleFastMRISliceDataset
    from deepinv.utils.demo import demo_mri_model, get_data_home
    from deepinv.models.utils import get_weights_url
    from deepinv.physics.generator import (
        GaussianMaskGenerator,
        BernoulliSplittingMaskGenerator,
    )

    torch.manual_seed(0)
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"


.. GENERATED FROM PYTHON SOURCE LINES 45-48

Load data
---------

.. GENERATED FROM PYTHON SOURCE LINES 48-84

.. code-block:: Python


    # In this example, we use a mini demo subset of the single-coil `FastMRI dataset `_
    # as the base image dataset, consisting of knees of size 320x320, and then resized to 128x128 for speed.
    #
    # .. important::
    #
    #    By using this dataset, you confirm that you have agreed to and signed the `FastMRI data use agreement `_.
    #
    # .. seealso::
    #
    #   Datasets :class:`deepinv.datasets.FastMRISliceDataset` :class:`deepinv.datasets.SimpleFastMRISliceDataset`
    #       We provide convenient datasets to easily load both raw and reconstructed FastMRI images.
    #       You can download more data on the `FastMRI site `_.
    #
    #
    # We use a train set of size 1 and test set of size 1 in this demo for
    # speed to fine-tune the original model. To train the original
    # model from scratch, use a larger dataset of size ~150.
    #

    batch_size = 1
    H = 128

    transform = transforms.Compose([transforms.Resize(H)])

    train_dataset = SimpleFastMRISliceDataset(
        get_data_home(), transform=transform, train=True, download=True, train_percent=0.5
    )
    test_dataset = SimpleFastMRISliceDataset(
        get_data_home(), transform=transform, train=False, train_percent=0.5
    )

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


.. GENERATED FROM PYTHON SOURCE LINES 85-97

Define physics
--------------

We simulate a sequential k-space sampler that, over the course of 4
phases (i.e. frames), samples 64 lines (i.e. 2x total undersampling from
128) with Gaussian weighting (plus a few extra for the ACS signals in
the centre of the k-space). We use
:class:`deepinv.physics.SequentialMRI` to do this.

First, we define a static 2x acceleration mask that all measurements use
(of shape [B,C,H,W]):

.. GENERATED FROM PYTHON SOURCE LINES 97-103

.. code-block:: Python


    mask_full = GaussianMaskGenerator((2, H, H), acceleration=2, device=device).step(
        batch_size=batch_size
    )["mask"]


.. GENERATED FROM PYTHON SOURCE LINES 104-107

Next, we randomly share the sampled lines across 4 time-phases into a
time-varying mask:

.. GENERATED FROM PYTHON SOURCE LINES 107-132

.. code-block:: Python


    # Split only in horizontal direction
    masks = [mask_full[..., 0, :]]
    splitter = BernoulliSplittingMaskGenerator((2, H), split_ratio=0.5, device=device)

    acs = 10

    # Split twice, giving 4 phase masks
    for _ in range(2):
        new_masks = []
        for m in masks:
            m1 = splitter.step(batch_size=batch_size, input_mask=m)["mask"]
            m2 = m - m1
            # Every phase always samples the central ACS lines
            m1[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
            m2[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
            new_masks.extend([m1, m2])
        masks = new_masks

    # Merge masks into time dimension
    mask = torch.stack(masks, 2)

    # Convert to vertical lines
    mask = torch.stack([mask] * H, -2)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/work/deepinv/deepinv/deepinv/physics/generator/inpainting.py:110: UserWarning: Generating pixelwise mask assumes channel in first dimension. For 2D images (i.e. of shape (H,W)) ensure tensor_size is at least 3D (i.e. C,H,W). However, for tensor_size of shape (C,M), this will work as expected.
      warn(


.. GENERATED FROM PYTHON SOURCE LINES 133-135

Now define physics using this time-varying mask of shape [B,C,T,H,W]:

.. GENERATED FROM PYTHON SOURCE LINES 135-139

.. code-block:: Python


    physics = dinv.physics.SequentialMRI(mask=mask)


.. GENERATED FROM PYTHON SOURCE LINES 140-144

Let's visualise the sequential measurements using a sample image (run
this notebook yourself to display the video). We also visualise the
frame-by-frame no-learning zero-filled reconstruction.

.. GENERATED FROM PYTHON SOURCE LINES 144-154

.. code-block:: Python


    x = next(iter(train_dataloader))
    y = physics(x)
    dinv.utils.plot_videos(
        [physics.repeat(x, mask), y, mask, physics.A_adjoint(y, keep_time_dim=True)],
        titles=["x", "y", "mask", "x_init"],
        display=True,
    )


.. image-sg:: /auto_examples/self-supervised-learning/images/sphx_glr_demo_artifact2artifact_001.png
   :alt: x, y, mask, x_init
   :srcset: /auto_examples/self-supervised-learning/images/sphx_glr_demo_artifact2artifact_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/work/deepinv/deepinv/deepinv/utils/plotting.py:792: UserWarning: IPython can't be found. Install it to use display=True. Skipping...
      warn("IPython can't be found. Install it to use display=True. Skipping...")
    /opt/hostedtoolcache/Python/3.9.21/x64/lib/python3.9/site-packages/matplotlib/animation.py:872: UserWarning: Animation was deleted without rendering anything. This is most likely not intended. To prevent deletion, assign the Animation to a variable, e.g. `anim`, that exists until you output the Animation using `plt.show()` or `anim.save()`.
      warnings.warn(


.. GENERATED FROM PYTHON SOURCE LINES 155-159

We also visualise the flattened time-series, recovering the original 2x
undersampling mask (note that the actual undersampling factor is much
lower due to the ACS lines):

.. GENERATED FROM PYTHON SOURCE LINES 159-168

.. code-block:: Python


    dinv.utils.plot(
        [x, physics.average(y), physics.average(mask), physics.A_adjoint(y)],
        titles=["x", "y", "orig mask", "x_init"],
    )

    print("Total acceleration:", (2 * 128 * 128) / mask.sum())


.. image-sg:: /auto_examples/self-supervised-learning/images/sphx_glr_demo_artifact2artifact_002.png
   :alt: x, y, orig mask, x_init
   :srcset: /auto_examples/self-supervised-learning/images/sphx_glr_demo_artifact2artifact_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Total acceleration: tensor(1.3617)


.. GENERATED FROM PYTHON SOURCE LINES 169-178

Define model
------------

As a (static) reconstruction network, we use an unrolled network
(half-quadratic splitting) with a trainable denoising prior based on the
DnCNN architecture as an example of a model-based deep learning architecture
from `MoDL `_.
See :func:`deepinv.utils.demo.demo_mri_model` for details.

.. GENERATED FROM PYTHON SOURCE LINES 178-182

.. code-block:: Python


    model = demo_mri_model(device=device)


.. GENERATED FROM PYTHON SOURCE LINES 183-191

Prep loss
---------

We compute the loss on all collected lines by setting ``dynamic_model``
to ``False``, then adapt the model to perform Artifact2Artifact. We set
``split_size=1`` so that each Artifact chunk contains only 1 frame.

.. GENERATED FROM PYTHON SOURCE LINES 191-198

.. code-block:: Python


    loss = dinv.loss.Artifact2ArtifactLoss(
        (2, 4, H, H), split_size=1, dynamic_model=False, device=device
    )
    model = loss.adapt_model(model)


.. GENERATED FROM PYTHON SOURCE LINES 199-207

Train model
-----------

The original model is trained for 100 epochs. We demonstrate loading the
pretrained model and then fine-tuning it for 1 epoch, reporting PSNR and
SSIM. To train from scratch, simply comment out the model loading code
and increase the number of epochs.

.. GENERATED FROM PYTHON SOURCE LINES 207-240

.. code-block:: Python


    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

    # Load pretrained model
    file_name = "demo_artifact2artifact_mri.pth"
    url = get_weights_url(model_name="measplit", file_name=file_name)
    ckpt = torch.hub.load_state_dict_from_url(
        url, map_location=lambda storage, loc: storage, file_name=file_name
    )

    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optimizer"])

    # Initialize the trainer
    trainer = dinv.Trainer(
        model,
        physics=physics,
        epochs=1,
        losses=loss,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
        online_measurements=True,
        device=device,
        save_path=None,
        verbose=True,
        wandb_vis=False,
        show_progress_bar=False,
    )

    model = trainer.train()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://huggingface.co/deepinv/measplit/resolve/main/demo_artifact2artifact_mri.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/demo_artifact2artifact_mri.pth
    0%| | 0.00/2.16M [00:00`


.. container:: sphx-glr-download sphx-glr-download-python

  :download:`Download Python source code: demo_artifact2artifact.py `

.. container:: sphx-glr-download sphx-glr-download-zip

  :download:`Download zipped: demo_artifact2artifact.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
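
The reported total acceleration of about 1.36, rather than 2, can be checked
by counting columns. The following is a minimal sketch in plain Python, not
the deepinv code: it assumes a uniformly random choice of 64 columns as a
stand-in for the Gaussian mask, and an exact half split in place of the
Bernoulli splitter. The names ``sampled``, ``phases`` and ``centre`` are
hypothetical and used for illustration only.

.. code-block:: Python

    import random

    random.seed(0)
    H, acs = 128, 10

    # Hypothetical stand-in for the 2x mask: keep 64 of the 128 k-space columns
    sampled = set(random.sample(range(H), H // 2))

    # Split the sampled columns in half twice, giving 4 disjoint phases
    phases = [sorted(sampled)]
    for _ in range(2):
        new_phases = []
        for cols in phases:
            first = set(random.sample(cols, len(cols) // 2))
            new_phases.extend([sorted(first), sorted(set(cols) - first)])
        phases = new_phases

    # Every phase additionally samples the central ACS columns
    centre = set(range(H // 2 - acs // 2, H // 2 + acs // 2))
    phases_acs = [set(p) | centre for p in phases]

    # Columns available per frame divided by columns acquired over all phases.
    # The channel and row factors of the full mask cancel in this ratio.
    acceleration = H / sum(len(p) for p in phases_acs)
    print("Total acceleration:", round(acceleration, 4))

Because the 10 ACS columns are acquired again in each of the 4 phases, the
number of acquired columns grows from 64 to between 94 and 104, depending on
how many ACS columns the original mask already contained. The ratio therefore
lies between roughly 1.23 and 1.36 in this sketch.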