.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/models/demo_training.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_models_demo_training.py:


Training a reconstruction model
====================================================================================================

This example is a quick-start introduction to training reconstruction networks
with DeepInverse to solve imaging inverse problems.

Training requires these components, all of which you can define with DeepInverse:

* A `model` to be trained, chosen from our :ref:`reconstructors `. Or, define your own.
* A `physics` from our :ref:`list of physics `. Or, :ref:`bring your own physics `.
* A `dataset` of images and/or measurements from :ref:`datasets `. Or, :ref:`bring your own dataset `.
* A `loss` from our :ref:`loss functions `.
* A `metric` from our :ref:`metrics `.

Here, we train a UNet on an inpainting task using the Urban100 dataset of natural images.

.. GENERATED FROM PYTHON SOURCE LINES 20-27

.. code-block:: Python


    import deepinv as dinv
    import torch

    # Use a GPU if one is available, otherwise fall back to the CPU.
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
    # Seeded generator, passed to the physics below for reproducibility.
    rng = torch.Generator(device=device).manual_seed(0)


.. GENERATED FROM PYTHON SOURCE LINES 28-33

Setup
-----

First, define the physics that we want to train on.

.. GENERATED FROM PYTHON SOURCE LINES 33-36

.. code-block:: Python


    physics = dinv.physics.Inpainting((1, 64, 64), mask=0.8, device=device, rng=rng)


.. GENERATED FROM PYTHON SOURCE LINES 37-43

Then define the dataset. Here we simulate a dataset of measurements from Urban100.

.. tip::

    See :ref:`datasets ` for the types of datasets DeepInverse supports, e.g.
    paired, ground-truth-free, single-image.

.. GENERATED FROM PYTHON SOURCE LINES 43-72

.. code-block:: Python


    from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop, Grayscale

    # Load Urban100 as 64x64 grayscale tensors.
    dataset = dinv.datasets.Urban100HR(
        ".",
        download=True,
        transform=Compose([ToTensor(), Grayscale(), Resize(256), CenterCrop(64)]),
    )

    # Keep the first 50 images and split them 80/20 into train and test sets.
    train_dataset, test_dataset = torch.utils.data.random_split(
        torch.utils.data.Subset(dataset, range(50)), (0.8, 0.2)
    )

    # Simulate measurements with the physics and save them to disk.
    dataset_path = dinv.datasets.generate_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        physics=physics,
        device=device,
        save_dir=".",
        batch_size=1,
    )

    train_dataloader = torch.utils.data.DataLoader(
        dinv.datasets.HDF5Dataset(dataset_path, train=True), shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        dinv.datasets.HDF5Dataset(dataset_path, train=False), shuffle=False
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%|          | 0/135388067 [00:00


See :ref:`loss functions ` for all supported state-of-the-art loss functions.

We evaluate using the PSNR metric.
See :ref:`metrics ` for all supported metrics.

.. note::

    In this example, we train for only a few epochs to keep the training time short.
    For good reconstruction quality, we recommend training for at least 100 epochs.

.. GENERATED FROM PYTHON SOURCE LINES 110-129

.. code-block:: Python


    # Supervised training with an MSE loss, evaluated with PSNR.
    trainer = dinv.Trainer(
        model=model,
        physics=physics,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=5,
        losses=dinv.loss.SupLoss(metric=dinv.metric.MSE()),
        metrics=dinv.metric.PSNR(),
        device=device,
        plot_images=True,
        show_progress_bar=False,
    )

    _ = trainer.train()


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_002.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_002.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_003.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_003.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_004.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_004.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_005.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_005.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_006.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_006.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_007.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_007.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_008.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_008.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_009.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_009.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_010.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_010.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_011.png
          :alt: Ground truth, Measurement, Reconstruction
          :srcset: /auto_examples/models/images/sphx_glr_demo_training_011.png
          :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The model has 443585 trainable parameters
    Train epoch 0: TotalLoss=0.027, PSNR=17.103
    Eval epoch 0: PSNR=22.209
    Best model saved at epoch 1
    Train epoch 1: TotalLoss=0.004, PSNR=24.858
    Eval epoch 1: PSNR=27.873
    Best model saved at epoch 2
    Train epoch 2: TotalLoss=0.002, PSNR=28.266
    Eval epoch 2: PSNR=25.191
    Train epoch 3: TotalLoss=0.002, PSNR=28.654
    Eval epoch 3: PSNR=30.454
    Best model saved at epoch 4
    Train epoch 4: TotalLoss=0.001, PSNR=30.681
    Eval epoch 4: PSNR=31.838
    Best model saved at epoch 5


.. GENERATED FROM PYTHON SOURCE LINES 130-135

Test the network
--------------------------------------------

We can now test the trained network with ``trainer.test`` (see also the
:func:`deepinv.test` function). The testing function computes the metrics,
then plots and saves the results.

.. GENERATED FROM PYTHON SOURCE LINES 135-137

.. code-block:: Python


    trainer.test(test_dataloader)


.. image-sg:: /auto_examples/models/images/sphx_glr_demo_training_012.png
   :alt: Ground truth, Measurement, No learning, Reconstruction
   :srcset: /auto_examples/models/images/sphx_glr_demo_training_012.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Eval epoch 0: PSNR=31.838, PSNR no learning=13.35
    Test results:
    PSNR no learning: 13.350 +- 2.000
    PSNR: 31.838 +- 2.291

    {'PSNR no learning': np.float64(13.350251770019531), 'PSNR no learning_std': np.float64(2.000165115581808), 'PSNR': np.float64(31.8384765625), 'PSNR_std': np.float64(2.2912631323438837)}


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 17.918 seconds)


.. _sphx_glr_download_auto_examples_models_demo_training.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_training.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_training.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_training.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_