.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_train_inpainting.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_train_inpainting.py:


Training a reconstruction network.
====================================================================================================

This example shows how to train a simple reconstruction network for an image
inpainting inverse problem.

.. GENERATED FROM PYTHON SOURCE LINES 9-17

.. code-block:: Python


    import deepinv as dinv
    from torch.utils.data import DataLoader
    import torch
    from pathlib import Path
    from torchvision import transforms
    from deepinv.utils.demo import load_dataset


.. GENERATED FROM PYTHON SOURCE LINES 18-21

Set up paths for data loading and results.
--------------------------------------------


.. GENERATED FROM PYTHON SOURCE LINES 21-31

.. code-block:: Python


    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    CKPT_DIR = BASE_DIR / "ckpts"

    # Set PyTorch's global random seed so that the example is reproducible.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"


.. GENERATED FROM PYTHON SOURCE LINES 32-36

Load base image datasets and degradation operators.
--------------------------------------------------------------------------------------------
In this example, we use the CBSD68 dataset for training and the set3c dataset for testing.
Images are cropped to 128x128 when a GPU is available and to 32x32 otherwise.

.. GENERATED FROM PYTHON SOURCE LINES 36-53

.. code-block:: Python


    operation = "inpainting"
    train_dataset_name = "CBSD68"
    test_dataset_name = "set3c"
    img_size = 128 if torch.cuda.is_available() else 32

    # Deterministic center crops for testing, random crops for training.
    test_transform = transforms.Compose(
        [transforms.CenterCrop(img_size), transforms.ToTensor()]
    )
    train_transform = transforms.Compose(
        [transforms.RandomCrop(img_size), transforms.ToTensor()]
    )

    train_dataset = load_dataset(train_dataset_name, train_transform)
    test_dataset = load_dataset(test_dataset_name, test_transform)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading datasets/CBSD68.zip
    0%| | 0.00/19.8M [00:00


.. container:: sphx-glr-download sphx-glr-download-python

  :download:`Download Python source code: demo_train_inpainting.py `

.. container:: sphx-glr-download sphx-glr-download-zip

  :download:`Download zipped: demo_train_inpainting.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
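
The introduction describes the task as an image inpainting inverse problem. As a minimal sketch of what that measurement model usually looks like, the snippet below masks out random pixels of an image batch with plain PyTorch. It assumes a binary mask shared across colour channels, and it does not use deepinv's own physics operators. The helper names ``random_inpainting_mask`` and ``apply_inpainting``, the ``keep_prob`` and ``seed`` parameters, and the random stand-in images are all hypothetical and chosen for illustration only.

```python
import torch


def random_inpainting_mask(height, width, keep_prob=0.5, seed=0):
    """Return a (1, H, W) binary mask; 1 marks an observed pixel, 0 a missing one.

    Hypothetical helper: each pixel is kept independently with probability keep_prob.
    """
    generator = torch.Generator().manual_seed(seed)
    return (torch.rand(1, height, width, generator=generator) < keep_prob).float()


def apply_inpainting(x, mask):
    """Assumed forward model y = mask * x for a batch x of shape (B, C, H, W)."""
    # The (1, H, W) mask broadcasts over the batch and channel dimensions.
    return x * mask


# Same size rule as in the example: 128 with a GPU, 32 without.
img_size = 128 if torch.cuda.is_available() else 32

x = torch.rand(2, 3, img_size, img_size)  # stand-in for a batch of cropped images
mask = random_inpainting_mask(img_size, img_size, keep_prob=0.5)
y = apply_inpainting(x, mask)  # measurements with the masked pixels set to zero
```

A reconstruction network is then trained to map measurements such as ``y`` back to the clean images ``x``. Sharing one mask across channels means a missing pixel loses all of its colour values at once, which is a modelling choice of this sketch rather than something the example states.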