.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unfolded/demo_learned_primal_dual.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unfolded_demo_learned_primal_dual.py:


Learned Primal-Dual algorithm for CT scan.
====================================================================================================

Implementation of the unfolded (learned) primal-dual algorithm of :footcite:t:`adler2018learned`,
where both the data fidelity and the prior are learned modules, distinct for each iteration.

The algorithm is trained for CT reconstruction on random phantoms. The phantoms are generated on the fly
during training using the ODL library (https://odlgroup.github.io/odl/).

.. GENERATED FROM PYTHON SOURCE LINES 11-23

.. code-block:: Python


    import deepinv as dinv
    from pathlib import Path
    import torch
    from torch.utils.data import DataLoader
    from deepinv.optim import BaseOptim
    from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
    from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
    from deepinv.models import PDNet_PrimalBlock, PDNet_DualBlock
    from deepinv.optim import Prior, DataFidelity
    from deepinv.models.utils import get_weights_url

.. GENERATED FROM PYTHON SOURCE LINES 24-27

Set up paths for data loading and results.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-38

.. code-block:: Python


    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set the global random seed from pytorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 5057.25 MiB free memory

.. GENERATED FROM PYTHON SOURCE LINES 39-42

Load degradation operator.
---------------------------------------------------
We consider a CT (tomography) operator with 30 projection angles and additive Gaussian noise
of standard deviation 0.01.

.. GENERATED FROM PYTHON SOURCE LINES 42-61

.. code-block:: Python


    img_size = 64
    n_channels = 1  # 3 for color images, 1 for gray-scale images
    operation = "CT"

    # Degradation parameters
    noise_level_img = 0.01

    # Generate the CT operator.
    physics = dinv.physics.Tomography(
        img_width=img_size,
        angles=30,
        circle=False,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
        normalize=True,
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Power iteration converged at iteration 9, ||A^T A||_2=1854.39

.. GENERATED FROM PYTHON SOURCE LINES 62-68

Define a custom iterator for the PDNet learned primal-dual algorithm.
---------------------------------------------------------------------
The iterator is a subclass of the Chambolle-Pock iterator :class:`deepinv.optim.optim_iterators.CPIteration`.
In PDNet, the primal (gStep) and dual (fStep) updates are directly replaced by neural networks.
We therefore redefine the fStep and gStep classes as plain proximal operators of the data fidelity
and of the prior, respectively. Both proximal operators are then defined as trainable models.

.. GENERATED FROM PYTHON SOURCE LINES 68-126

.. code-block:: Python



    class PDNetIteration(CPIteration):
        r"""Single iteration of learned primal dual.
        We only redefine the fStep and gStep classes.
        The forward method is inherited from the CPIteration class.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepPDNet(**kwargs)
            self.f_step = fStepPDNet(**kwargs)


    class fStepPDNet(fStep):
        r"""
        Dual update of the PDNet algorithm.
        We write it as a proximal operator of the data fidelity term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_data_fidelity, y, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`u`.
            :param torch.Tensor w: Current second variable :math:`A z`.
            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
            :param torch.Tensor y: Input data.
            """
            return cur_data_fidelity.prox(x, w, y)


    class gStepPDNet(gStep):
        r"""
        Primal update of the PDNet algorithm.
        We write it as a proximal operator of the prior term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_prior, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`x`.
            :param torch.Tensor w: Current second variable :math:`A^\top u`.
            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
            """
            return cur_prior.prox(x, w)


    class PDNet_optim(BaseOptim):
        def __init__(self, **kwargs):
            super(PDNet_optim, self).__init__(PDNetIteration(), **kwargs)

.. GENERATED FROM PYTHON SOURCE LINES 127-131

Define the trainable prior and data fidelity terms.
---------------------------------------------------
The prior and the data fidelity are defined as subclasses of :class:`deepinv.optim.Prior` and
:class:`deepinv.optim.DataFidelity`, respectively. Their proximal operators are replaced by trainable models.
Each module uses only one channel of its second input ``w``: channel 0 for the prior and channel 1 for
the data fidelity.

.. GENERATED FROM PYTHON SOURCE LINES 131-166

.. code-block:: Python



    class PDNetPrior(Prior):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w):
            return self.model(x, w[:, 0:1, :, :])


    class PDNetDataFid(DataFidelity):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, y):
            return self.model(x, w[:, 1:2, :, :], y)


    # Unrolled optimization algorithm parameters
    max_iter = 10

    # Set up the data fidelity term. Each layer has its own data fidelity module.
    data_fidelity = [
        PDNetDataFid(model=PDNet_DualBlock().to(device)) for _ in range(max_iter)
    ]

    # Set up the trainable prior. Each layer has its own prior module.
    prior = [PDNetPrior(model=PDNet_PrimalBlock().to(device)) for _ in range(max_iter)]

    # Logging parameters
    verbose = True

.. GENERATED FROM PYTHON SOURCE LINES 167-170

Define the training parameters.
-------------------------------
We use the Adam optimizer with a cosine annealing learning-rate scheduler (``CosineAnnealingLR``).
On CPU, the number of epochs and of training samples is reduced to keep the running time manageable.

.. GENERATED FROM PYTHON SOURCE LINES 170-183

.. code-block:: Python


    # training parameters
    epochs = 10 if torch.cuda.is_available() else 2
    learning_rate = 1e-3
    num_workers = 4 if torch.cuda.is_available() else 0
    train_batch_size = 5
    test_batch_size = 1
    n_iter_training = int(1e4) if torch.cuda.is_available() else 100
    n_data = 1  # number of channels in the input
    n_primal = 5  # number of channels of the extended primal space
    n_dual = 5  # number of channels of the extended dual space

.. GENERATED FROM PYTHON SOURCE LINES 184-195

Define the unfolded trainable model.
-------------------------------------
In the original learned primal-dual paper, the authors used the adjoint operator in the primal update.
However, some of the same authors later showed, in A. Hauptmann, J. Adler, S. Arridge and O. Öktem,
"Multi-scale learned iterative reconstruction", IEEE Transactions on Computational Imaging 6, 843-856, 2020,
that using a filtered gradient can significantly improve both the training speed and the reconstruction quality.
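
The toy sketch below illustrates why a filtered gradient can help. It makes two substitutions, both assumptions made for illustration only: a random dense matrix ``A`` stands in for the CT operator, and the Moore-Penrose pseudo-inverse ``np.linalg.pinv(A)`` stands in for the filtered backprojection. Neither reflects how the library implements the tomography operator or its filtered backprojection. The sketch compares ten plain gradient steps :math:`x \leftarrow x + \tau A^\top (y - Ax)` with step size :math:`\tau` against a single step in which :math:`A^\top` is replaced by the pseudo-inverse. The helper ``relative_error`` and all variable names are hypothetical.

.. code-block:: Python

    import numpy as np

    # Toy stand-in for the CT operator (assumption): an overdetermined random matrix.
    rng = np.random.default_rng(0)
    m, n = 60, 20
    A = rng.standard_normal((m, n))
    x_true = rng.standard_normal(n)
    y = A @ x_true  # noiseless measurements

    A_adjoint = A.T               # plays the role of the adjoint operator
    A_dagger = np.linalg.pinv(A)  # plays the role of the filtered backprojection (assumption)


    def relative_error(back_op, step, n_iter):
        """Run x <- x + step * back_op (y - A x) from x = 0 and return the relative error."""
        x = np.zeros(n)
        for _ in range(n_iter):
            x = x + step * back_op @ (y - A @ x)
        return np.linalg.norm(x - x_true) / np.linalg.norm(x_true)


    # Plain gradient descent: the step size must stay below 2 / ||A^T A||_2.
    lipschitz = np.linalg.norm(A, 2) ** 2
    err_adjoint = relative_error(A_adjoint, 1.0 / lipschitz, n_iter=10)

    # "Filtered" update: one step with the pseudo-inverse.
    err_dagger = relative_error(A_dagger, 1.0, n_iter=1)

    print(f"adjoint, 10 steps:      {err_adjoint:.2e}")
    print(f"pseudo-inverse, 1 step: {err_dagger:.2e}")

In this toy setting, the adjoint-based iteration still has a much larger error after ten steps. The single pseudo-inverse step is exact up to rounding, because ``A`` has full column rank and the data are noiseless. For the Radon transform, the filtered backprojection plays a comparable role as an approximate inverse.
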
Following this approach, we use the filtered backprojection (``physics.A_dagger``) instead of the adjoint
operator in the primal step, by passing it as ``K_adjoint`` in ``params_algo``. The primal variable is
initialized with ``n_primal`` copies of the filtered backprojection and the dual variable with ``n_dual`` zero
channels. The network output is the second channel of the primal estimate ``X["est"][0]``.

.. GENERATED FROM PYTHON SOURCE LINES 195-227

.. code-block:: Python



    def custom_init(y, physics):
        x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
        u0 = torch.zeros_like(y).repeat(1, n_dual, 1, 1)
        return (x0, x0, u0)


    def custom_output(X):
        return X["est"][0][:, 1, :, :].unsqueeze(1)


    model = PDNet_optim(
        unfold=True,
        params_algo={"K": physics.A, "K_adjoint": physics.A_dagger},
        trainable_params=[],
        data_fidelity=data_fidelity,
        prior=prior,
        max_iter=max_iter,
        custom_init=custom_init,
        get_output=custom_output,
    )

    # choose optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=epochs
    )

    # choose supervised training loss
    losses = [dinv.loss.SupLoss(metric=dinv.metric.MSE())]

.. GENERATED FROM PYTHON SOURCE LINES 228-230

Training dataset of random phantoms.
--------------------------------------------------------
Random phantoms are used for training, and the Shepp-Logan phantom is used for evaluation.

.. GENERATED FROM PYTHON SOURCE LINES 230-246

.. code-block:: Python


    # Define the base train and test datasets of clean images.
    train_dataset_name = "random_phantom"
    train_dataset = RandomPhantomDataset(
        size=img_size, n_data=1, length=n_iter_training // epochs
    )
    test_dataset = SheppLoganDataset(size=img_size, n_data=1)

    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=num_workers
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=test_batch_size, num_workers=num_workers
    )

.. GENERATED FROM PYTHON SOURCE LINES 247-250

Train the network
----------------------------------------------------------------------------------------
We train the network using the library's :class:`deepinv.Trainer` class. Measurements are simulated
on the fly from the clean training images (``online_measurements=True``).

.. GENERATED FROM PYTHON SOURCE LINES 250-280

.. code-block:: Python


    trainer = dinv.Trainer(
        model,
        physics=physics,
        losses=losses,
        optimizer=optimizer,
        epochs=epochs,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        device=device,
        online_measurements=True,
        save_path=str(CKPT_DIR / operation),
        verbose=verbose,
        show_progress_bar=False,  # disable the progress bar for cleaner output in the Sphinx-Gallery page.
    )

    # If working on CPU, start from a pretrained model to reduce training time.
    if not torch.cuda.is_available():
        file_name = "ckp_PDNet.pth"
        url = get_weights_url(model_name="demo", file_name=file_name)
        ckpt = torch.hub.load_state_dict_from_url(
            url, map_location=lambda storage, loc: storage, file_name=file_name
        )

        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    model = trainer.train()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1352: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      self.setup_train()
    The model has 251980 trainable parameters
    Train epoch 0: TotalLoss=0.002, PSNR=29.6
    Eval epoch 0: PSNR=21.799
    Best model saved at epoch 1
    Train epoch 1: TotalLoss=0.001, PSNR=31.491
    Eval epoch 1: PSNR=22.731
    Best model saved at epoch 2
    Train epoch 2: TotalLoss=0.001, PSNR=31.99
    Eval epoch 2: PSNR=23.423
    Best model saved at epoch 3
    Train epoch 3: TotalLoss=0.001, PSNR=32.198
    Eval epoch 3: PSNR=23.908
    Best model saved at epoch 4
    Train epoch 4: TotalLoss=0.001, PSNR=32.486
    Eval epoch 4: PSNR=24.173
    Best model saved at epoch 5
    Train epoch 5: TotalLoss=0.001, PSNR=32.657
    Eval epoch 5: PSNR=25.174
    Best model saved at epoch 6
    Train epoch 6: TotalLoss=0.001, PSNR=32.768
    Eval epoch 6: PSNR=25.164
    Train epoch 7: TotalLoss=0.001, PSNR=32.884
    Eval epoch 7: PSNR=25.019
    Train epoch 8: TotalLoss=0.001, PSNR=32.907
    Eval epoch 8: PSNR=25.305
    Best model saved at epoch 9
    Train epoch 9: TotalLoss=0.001, PSNR=33.191
    Eval epoch 9: PSNR=25.508
    Best model saved at epoch 10

.. GENERATED FROM PYTHON SOURCE LINES 281-285

Test the network
--------------------------------------------
The trained model is evaluated on the Shepp-Logan test image and compared with the linear reconstruction
obtained by applying the adjoint operator to the measurements.

.. GENERATED FROM PYTHON SOURCE LINES 285-304

.. code-block:: Python


    trainer.test(test_dataloader)

    test_sample = next(iter(test_dataloader))
    model.eval()
    test_sample = test_sample.to(device)

    # Get the measurements and the ground truth
    y = physics(test_sample)
    with torch.no_grad():  # it is important to disable gradient computation during testing.
        rec = model(y, physics=physics)

    backprojected = physics.A_adjoint(y)

    dinv.utils.plot(
        [backprojected, rec, test_sample],
        titles=["Linear", "Reconstruction", "Ground truth"],
        suptitle="Reconstruction results",
    )


.. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_001.png
   :alt: Reconstruction results, Linear, Reconstruction, Ground truth
   :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1544: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      self.setup_train(train=False)
    Eval epoch 0: PSNR=25.396, PSNR no learning=12.19
    Test results:
    PSNR no learning: 12.190 +- 0.002
    PSNR: 25.396 +- 0.004
    /local/jtachell/deepinv/deepinv/deepinv/utils/plotting.py:387: UserWarning: This figure was using a layout engine that is incompatible with subplots_adjust and/or tight_layout; not calling subplots_adjust.
      fig.subplots_adjust(top=0.75)

.. GENERATED FROM PYTHON SOURCE LINES 305-308

:References:

.. footbibliography::


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (49 minutes 58.120 seconds)


.. _sphx_glr_download_auto_examples_unfolded_demo_learned_primal_dual.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_learned_primal_dual.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_learned_primal_dual.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_learned_primal_dual.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_