.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unfolded/demo_learned_primal_dual.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unfolded_demo_learned_primal_dual.py:


Learned Primal-Dual algorithm for CT scan.
====================================================================================================

Implementation of the unfolded primal-dual algorithm from

Adler, Jonas, and Ozan Öktem.
"Learned primal-dual reconstruction."
IEEE Transactions on Medical Imaging 37.6 (2018): 1322-1332.

in which both the data-fidelity and the prior steps are learned modules, with a distinct module for each iteration.

The algorithm is trained for CT reconstruction on random phantoms, which are generated on the fly during training using the odl library (https://odlgroup.github.io/odl/).

.. GENERATED FROM PYTHON SOURCE LINES 15-26

.. code-block:: Python


    import deepinv as dinv
    from pathlib import Path
    import torch
    from torch.utils.data import DataLoader
    from deepinv.unfolded import unfolded_builder
    from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
    from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
    from deepinv.models import PDNet_PrimalBlock, PDNet_DualBlock
    from deepinv.optim import Prior, DataFidelity

.. GENERATED FROM PYTHON SOURCE LINES 27-30

Set up paths for data loading and results.
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 30-41

.. code-block:: Python


    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    RESULTS_DIR = BASE_DIR / "results"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set the global random seed from pytorch to ensure reproducibility of the example.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

.. GENERATED FROM PYTHON SOURCE LINES 42-45

Load degradation operator.
---------------------------------------------------
We consider a tomography (CT) operator with 30 projection angles and additive Gaussian noise of standard deviation 0.05.

.. GENERATED FROM PYTHON SOURCE LINES 45-63

.. code-block:: Python


    img_size = 128 if torch.cuda.is_available() else 32
    n_channels = 1  # 3 for color images, 1 for gray-scale images
    operation = "CT"

    # Degradation parameters
    noise_level_img = 0.05

    # Generate the CT operator.
    physics = dinv.physics.Tomography(
        img_width=img_size,
        angles=30,
        circle=False,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
    )

.. GENERATED FROM PYTHON SOURCE LINES 64-70

Define a custom iterator for the PDNet learned primal-dual algorithm.
---------------------------------------------------------------------
The iterator is a subclass of the Chambolle-Pock iterator :class:`deepinv.optim.optim_iterators.CPIteration`.
In PDNet, the primal (``gStep``) and dual (``fStep``) updates are replaced directly by neural networks.
We therefore redefine the ``fStep`` and ``gStep`` classes as thin wrappers that call the proximal operator of the data-fidelity term and of the prior, respectively.
These proximal operators are then implemented as trainable models.

.. GENERATED FROM PYTHON SOURCE LINES 70-123

.. code-block:: Python



    class PDNetIteration(CPIteration):
        r"""Single iteration of learned primal-dual.
        We only redefine the fStep and gStep classes.
        The forward method is inherited from the CPIteration class.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepPDNet(**kwargs)
            self.f_step = fStepPDNet(**kwargs)


    class fStepPDNet(fStep):
        r"""
        Dual update of the PDNet algorithm.
        We write it as a proximal operator of the data fidelity term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_data_fidelity, y, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`u`.
            :param torch.Tensor w: Current second variable :math:`A z`.
            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
            :param torch.Tensor y: Input data.
            """
            return cur_data_fidelity.prox(x, w, y)


    class gStepPDNet(gStep):
        r"""
        Primal update of the PDNet algorithm.
        We write it as a proximal operator of the prior term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_prior, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`x`.
            :param torch.Tensor w: Current second variable :math:`A^\top u`.
            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
            """
            return cur_prior.prox(x, w)

.. GENERATED FROM PYTHON SOURCE LINES 124-128

Define the trainable prior and data fidelity terms.
---------------------------------------------------
The prior and data-fidelity terms are defined as subclasses of :class:`deepinv.optim.Prior` and :class:`deepinv.optim.DataFidelity`, respectively.
Their proximal operators are replaced by trainable models. Each model receives its current variable, one channel of the auxiliary input ``w`` (index 0 for the prior, index 1 for the data fidelity) and, for the data fidelity, the measurements ``y``.

.. GENERATED FROM PYTHON SOURCE LINES 128-164

.. code-block:: Python



    class PDNetPrior(Prior):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w):
            return self.model(x, w[:, 0:1, :, :])


    class PDNetDataFid(DataFidelity):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, y):
            return self.model(x, w[:, 1:2, :, :], y)


    # Unrolled optimization algorithm parameters
    max_iter = 10 if torch.cuda.is_available() else 3  # number of unfolded layers

    # Set up the data fidelity term. Each layer has its own data fidelity module.
    data_fidelity = [
        PDNetDataFid(model=PDNet_DualBlock().to(device)) for i in range(max_iter)
    ]

    # Set up the trainable prior. Each layer has its own prior module.
    prior = [PDNetPrior(model=PDNet_PrimalBlock().to(device)) for i in range(max_iter)]

    # Logging parameters
    verbose = True
    wandb_vis = False  # plot curves and images in Weights & Biases

.. GENERATED FROM PYTHON SOURCE LINES 165-168

Define the training parameters.
-------------------------------
We use the Adam optimizer and a cosine annealing learning-rate scheduler (:class:`torch.optim.lr_scheduler.CosineAnnealingLR`).

.. GENERATED FROM PYTHON SOURCE LINES 168-181

.. code-block:: Python


    # training parameters
    epochs = 10
    learning_rate = 1e-3
    num_workers = 4 if torch.cuda.is_available() else 0
    train_batch_size = 5
    test_batch_size = 1
    n_iter_training = int(1e4) if torch.cuda.is_available() else 100
    n_data = 1  # number of channels in the input
    n_primal = 5  # number of channels of the extended primal variable
    n_dual = 5  # number of channels of the extended dual variable

.. GENERATED FROM PYTHON SOURCE LINES 182-184

Define the model.
-------------------------------
The two functions below initialize the unrolled iterates and extract the final reconstruction.
The primal variable is initialized by replicating ``physics.A_dagger(y)`` over ``n_primal`` channels, the dual variable (which lives in the measurement space) is initialized to zero with ``n_dual`` channels, and the output is the primal channel with index 1.

.. GENERATED FROM PYTHON SOURCE LINES 184-196

.. code-block:: Python



    def custom_init(y, physics):
        x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
        u0 = torch.zeros_like(y).repeat(1, n_dual, 1, 1)
        return {"est": (x0, x0, u0)}


    def custom_output(X):
        return X["est"][0][:, 1, :, :].unsqueeze(1)

.. GENERATED FROM PYTHON SOURCE LINES 197-208

Define the unfolded trainable model.
-------------------------------------
In the original learned primal-dual paper, the authors used the adjoint operator
in the primal update. However, the same authors, together with others, found in the paper

A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
Multi-scale learned iterative reconstruction,
IEEE Transactions on Computational Imaging 6, 843-856, 2020.

that using a filtered gradient can significantly improve both training speed and reconstruction quality.
Following this approach, we use the filtered backprojection (``physics.A_dagger``, passed as ``K_adjoint``) instead of the adjoint operator in the primal step.

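
As a loose illustration of why a filtered backprojection can speed things up, consider the plain iteration :math:`x \leftarrow x - \tau B (A x - y)` on a toy least-squares problem, with either :math:`B = A^\top` (the adjoint) or :math:`B = A^\dagger` (the pseudo-inverse). The sketch below is our own analogy, not the experiment of Hauptmann et al. A random Gaussian matrix stands in for the CT operator, and ``numpy.linalg.pinv`` stands in for the filtered backprojection, which only approximates a pseudo-inverse of the Radon transform. The problem sizes, step sizes and the helper ``relative_error_after`` are illustrative assumptions.

.. code-block:: Python

    import numpy as np

    # Toy stand-ins (assumptions): a random Gaussian matrix plays the role of the
    # CT operator K, and its pseudo-inverse plays the role of a filtered backprojection.
    rng = np.random.default_rng(0)
    m, n = 60, 40  # "measurement" and "image" sizes, chosen arbitrarily
    A = rng.standard_normal((m, n))
    x_true = rng.standard_normal(n)
    y = A @ x_true  # noiseless measurements, to isolate the effect of the back-operator


    def relative_error_after(back_op, tau, n_steps):
        """Iterate x = x - tau * back_op @ (A @ x - y) from x = 0; return the relative error."""
        x = np.zeros(n)
        for _ in range(n_steps):
            x = x - tau * back_op @ (A @ x - y)
        return np.linalg.norm(x - x_true) / np.linalg.norm(x_true)


    # Adjoint (Landweber) step: tau = 1 / ||A||_2^2 keeps the iteration stable.
    tau_adjoint = 1.0 / np.linalg.norm(A, 2) ** 2
    err_adjoint = relative_error_after(A.T, tau_adjoint, n_steps=3)

    # "Filtered" step: A has full column rank here, so with the exact pseudo-inverse
    # and tau = 1 the first step already recovers x_true.
    err_filtered = relative_error_after(np.linalg.pinv(A), 1.0, n_steps=3)

    print(f"after 3 steps: adjoint {err_adjoint:.3f}, pseudo-inverse {err_filtered:.1e}")

Under these assumptions, the pseudo-inverse iteration fits the noiseless toy data after one step, while the adjoint iteration is still far from the solution after three. With noisy, ill-conditioned operators an exact pseudo-inverse would amplify noise, so the toy only conveys the preconditioning intuition.
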

.. GENERATED FROM PYTHON SOURCE LINES 208-228

.. code-block:: Python


    model = unfolded_builder(
        iteration=PDNetIteration(),
        params_algo={"K": physics.A, "K_adjoint": physics.A_dagger, "beta": 0.0},
        data_fidelity=data_fidelity,
        prior=prior,
        max_iter=max_iter,
        custom_init=custom_init,
        get_output=custom_output,
    )

    # choose optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=epochs
    )

    # choose supervised training loss
    losses = [dinv.loss.SupLoss(metric=dinv.metric.MSE())]

.. GENERATED FROM PYTHON SOURCE LINES 229-231

Training dataset of random phantoms.
--------------------------------------------------------
The network is trained on random phantoms generated on the fly and evaluated on the Shepp-Logan phantom.
Each training epoch contains ``n_iter_training // epochs`` phantoms.

.. GENERATED FROM PYTHON SOURCE LINES 231-247

.. code-block:: Python


    # Define the base train and test datasets of clean images.
    train_dataset_name = "random_phantom"
    train_dataset = RandomPhantomDataset(
        size=img_size, n_data=1, length=n_iter_training // epochs
    )
    test_dataset = SheppLoganDataset(size=img_size, n_data=1)

    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=num_workers
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=test_batch_size, num_workers=num_workers
    )

.. GENERATED FROM PYTHON SOURCE LINES 248-251

Train the network
----------------------------------------------------------------------------------------
We train the network with the library's :class:`deepinv.Trainer` class, which simulates the noisy measurements online from the clean phantoms (``online_measurements=True``).

.. GENERATED FROM PYTHON SOURCE LINES 251-280

.. code-block:: Python


    method = "learned primal-dual"
    save_folder = RESULTS_DIR / method / operation
    plot_images = True  # Images are saved in save_folder.
    plot_convergence_metrics = (
        True  # compute performance and convergence metrics along the algorithm.
    )

    trainer = dinv.Trainer(
        model,
        physics=physics,
        losses=losses,
        optimizer=optimizer,
        epochs=epochs,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        device=device,
        plot_convergence_metrics=plot_convergence_metrics,
        online_measurements=True,
        save_path=str(CKPT_DIR / operation),
        verbose=verbose,
        show_progress_bar=False,  # disable the progress bar for a cleaner Sphinx-Gallery output.
        wandb_vis=wandb_vis,  # training visualization can be done in Weights & Biases
    )

    model = trainer.train()




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_001.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_002.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_003.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_004.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_005.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_006.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_007.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_008.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_008.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_009.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_009.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_010.png
         :alt: PSNR, residual
         :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_010.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The model has 75595 trainable parameters
    Train epoch 0: TotalLoss=0.002, PSNR=28.082
    Eval epoch 0: PSNR=19.298
    Train epoch 1: TotalLoss=0.002, PSNR=26.705
    Eval epoch 1: PSNR=19.373
    Train epoch 2: TotalLoss=0.002, PSNR=29.883
    Eval epoch 2: PSNR=19.866
    Train epoch 3: TotalLoss=0.001, PSNR=29.65
    Eval epoch 3: PSNR=19.622
    Train epoch 4: TotalLoss=0.001, PSNR=30.77
    Eval epoch 4: PSNR=19.531
    Train epoch 5: TotalLoss=0.002, PSNR=27.483
    Eval epoch 5: PSNR=19.844
    Train epoch 6: TotalLoss=0.001, PSNR=30.088
    Eval epoch 6: PSNR=19.975
    Train epoch 7: TotalLoss=0.001, PSNR=31.011
    Eval epoch 7: PSNR=20.059
    Train epoch 8: TotalLoss=0.001, PSNR=29.679
    Eval epoch 8: PSNR=20.057
    Train epoch 9: TotalLoss=0.001, PSNR=31.864
    Eval epoch 9: PSNR=20.049




.. GENERATED FROM PYTHON SOURCE LINES 281-285

Test the network
--------------------------------------------
We evaluate the trained model on the Shepp-Logan test phantom.

.. GENERATED FROM PYTHON SOURCE LINES 285-287

.. code-block:: Python

    trainer.test(test_dataloader)

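
The test log below reports PSNR values, :math:`\mathrm{PSNR} = 10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})`. A negative PSNR, such as the -41.611 dB reported as "PSNR no learning", simply means that the mean squared error exceeds :math:`\mathrm{MAX}^2`. The hypothetical helpers below convert the logged values back to an MSE, assuming a peak value :math:`\mathrm{MAX} = 1`. That default is our assumption about the metric. The results are roughly 0.01 for the trained model and roughly :math:`1.4 \times 10^4` for the baseline.

.. code-block:: Python

    import math


    def psnr_from_mse(mse, max_pixel=1.0):
        """PSNR in dB for a given mean squared error and peak value."""
        return 10.0 * math.log10(max_pixel**2 / mse)


    def mse_from_psnr(psnr_db, max_pixel=1.0):
        """Invert the PSNR formula to recover the mean squared error."""
        return max_pixel**2 / 10.0 ** (psnr_db / 10.0)


    # PSNR values copied from the test log; max_pixel = 1 is an assumption.
    for label, value in [("trained model", 20.068), ("no learning", -41.611)]:
        print(f"{label}: PSNR {value:.3f} dB -> MSE {mse_from_psnr(value):.3g}")
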

.. image-sg:: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_011.png
   :alt: PSNR, residual
   :srcset: /auto_examples/unfolded/images/sphx_glr_demo_learned_primal_dual_011.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Eval epoch 0: PSNR=20.068, PSNR no learning=-41.611
    Test results:
    PSNR no learning: -41.611 +- 0.000
    PSNR: 20.068 +- 0.000

    {'PSNR no learning': np.float64(-41.61101531982422), 'PSNR no learning_std': 0, 'PSNR': np.float64(20.06831932067871), 'PSNR_std': 0}




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.693 seconds)


.. _sphx_glr_download_auto_examples_unfolded_demo_learned_primal_dual.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_learned_primal_dual.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_learned_primal_dual.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_learned_primal_dual.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_