.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_custom_dataset.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics in the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_custom_dataset.py:


Bring your own dataset
=======================

This example shows how to use DeepInverse with your own dataset.

A dataset in DeepInverse can contain ground-truth images `x`, measurements `y`, :ref:`physics parameters ` `params`, or any combination of these.

See the :ref:`datasets user guide ` for the formats in which data must be returned to be compatible with DeepInverse (e.g., for use with :class:`deepinv.Trainer`).

DeepInverse provides multiple ways of bringing your own dataset. This example has two parts: first, how to load images or other data into a dataset; second, how to use this dataset with DeepInverse.

.. GENERATED FROM PYTHON SOURCE LINES 16-22

.. code-block:: Python

    import deepinv as dinv
    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Part 1: Loading data into a dataset
-----------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-34

You have a folder of ground truth images
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here, we imagine that we have a folder containing a single ground truth image of a butterfly.

.. tip::

    :class:`deepinv.datasets.ImageFolder` can load any type of data (e.g., MRI or CT) by passing in a custom `loader` function and `transform`.

.. GENERATED FROM PYTHON SOURCE LINES 34-45
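As a minimal sketch of the tip above, a custom loader for non-image data could look like the hypothetical `load_npy` function below. It assumes, following the torchvision `ImageFolder` convention, that `loader` receives a file path and returns a single sample; the file layout (one 2D slice per `.npy` file) is also an assumption. Since this loader already returns a normalized `float32` tensor of shape `(1, H, W)`, no `ToTensor` transform would be needed.

.. code-block:: Python

    import tempfile
    from pathlib import Path

    import numpy as np
    import torch


    def load_npy(path):
        """Hypothetical loader: read a 2D array (e.g. an MRI or CT slice) saved
        as .npy and return it as a float32 tensor of shape (1, H, W) in [0, 1]."""
        sample = torch.from_numpy(np.load(path)).float()
        # Min-max rescaling; the small constant avoids division by zero
        sample = (sample - sample.min()) / (sample.max() - sample.min() + 1e-8)
        return sample.unsqueeze(0)  # add a channel dimension


    # Usage with a synthetic slice written to a temporary folder
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "slice_000.npy"
        np.save(path, np.arange(12, dtype=np.float64).reshape(3, 4))
        sample = load_npy(path)

    print(sample.shape, sample.dtype)  # torch.Size([1, 3, 4]) torch.float32

If the assumed convention holds, the function could then be passed as `loader=load_npy` together with an `x_path` pattern such as `"*.npy"`; check the :class:`deepinv.datasets.ImageFolder` documentation for the exact arguments it accepts.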
.. code-block:: Python

    DATA_DIR = dinv.utils.demo.get_data_home() / "demo_custom_dataset"
    dinv.utils.download_example("butterfly.png", DATA_DIR / "GT")

    dataset1 = dinv.datasets.ImageFolder(DATA_DIR / "GT", transform=ToTensor())

    # Load one image from the dataset
    x = next(iter(DataLoader(dataset1)))

    dinv.utils.plot({"x": x})

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_001.png
   :alt: x
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 46-50

You have a folder of paired ground truth and measurements
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now imagine that we have a ground truth folder containing a butterfly, and a measurements folder containing a masked version of the same butterfly.

.. GENERATED FROM PYTHON SOURCE LINES 50-61

.. code-block:: Python

    dinv.utils.download_example("butterfly_masked.png", DATA_DIR / "Measurements")

    dataset2 = dinv.datasets.ImageFolder(
        DATA_DIR, x_path="GT/*.png", y_path="Measurements/*.png", transform=ToTensor()
    )

    x, y = next(iter(DataLoader(dataset2)))

    dinv.utils.plot({"x": x, "y": y})

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_002.png
   :alt: x, y
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 62-68

.. note::

    If you're loading measurements that have randomly varying `params`, your dataset must return tuples `(x, y, params)` so that the physics can be updated accordingly for every image. We provide a convenience argument, `ImageFolder(estimate_params=...)`, to help you estimate these `params` on the fly.

.. GENERATED FROM PYTHON SOURCE LINES 70-73

You have a folder of only measurements
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now imagine that you have no ground truth, only measurements. In this case, `x` should be loaded as NaN:

.. GENERATED FROM PYTHON SOURCE LINES 73-81
.. code-block:: Python

    dataset3 = dinv.datasets.ImageFolder(
        DATA_DIR, y_path="Measurements/*.png", transform=ToTensor()
    )

    x, y = next(iter(DataLoader(dataset3)))
    print(x)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([nan], dtype=torch.float64)

.. GENERATED FROM PYTHON SOURCE LINES 82-87

You already have tensors
~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes you may already have your data as tensors. In this case, you can construct a dataset using :class:`deepinv.datasets.TensorDataset`. Here, for example, we build an unsupervised dataset containing a single measurement, which is loaded as a tuple `(nan, y)`:

.. GENERATED FROM PYTHON SOURCE LINES 87-95

.. code-block:: Python

    y = dinv.utils.load_example("butterfly_masked.png")

    dataset4 = dinv.datasets.TensorDataset(y=y)

    x, y = next(iter(DataLoader(dataset4)))
    print(x)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([nan], dtype=torch.float64)

.. GENERATED FROM PYTHON SOURCE LINES 96-99

You already have a PyTorch dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Suppose you already have your own PyTorch dataset:

.. GENERATED FROM PYTHON SOURCE LINES 99-111

.. code-block:: Python

    # Toy dataset with a single sample made of dummy tensors
    class MyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 1

        def __getitem__(self, i):
            # Returns (x, y, params)
            return torch.zeros(1), torch.zeros(1), {"mask": torch.zeros(1)}


    dataset5 = MyDataset()

.. GENERATED FROM PYTHON SOURCE LINES 112-114

You should check that your dataset is compatible using :func:`deepinv.datasets.check_dataset` (alternatively, inherit from :class:`deepinv.datasets.ImageDataset` and call `self.check_dataset()`):

.. GENERATED FROM PYTHON SOURCE LINES 114-117

.. code-block:: Python

    dinv.datasets.check_dataset(dataset5)

.. GENERATED FROM PYTHON SOURCE LINES 118-120

Part 2: Using your dataset with DeepInverse
-------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 122-123

Now suppose you have already set up a DeepInverse problem:

.. GENERATED FROM PYTHON SOURCE LINES 123-128
.. code-block:: Python

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    physics = dinv.physics.Inpainting(img_size=(3, 256, 256))

    model = dinv.models.RAM(pretrained=True, device=device)

.. GENERATED FROM PYTHON SOURCE LINES 129-139

If your dataset already returns measurements in the form `(x, y)` or `(x, y, params)`, you can test with it directly.

Here, however, our physics does not yet know the `params` (i.e., the inpainting mask). Since the mask is fixed across the dataset, we can set it manually by estimating it from `y`:

.. note::

    If you're loading measurements that have randomly varying `params`, your dataset must return tuples `(x, y, params)` so that the physics can be updated accordingly for every image.

.. GENERATED FROM PYTHON SOURCE LINES 139-145

.. code-block:: Python

    # Pixels that are exactly zero in the measurement are treated as masked out
    params = {"mask": (dataset2[0][1].to(device) != 0).float()}
    physics.update(**params)

    dinv.test(model, DataLoader(dataset2), physics, plot_images=True, device=device)

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_003.png
   :alt: Ground truth, Measurement, No learning, Reconstruction
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/1 [00:00

` to simulate random physics and then use `load_physics_generator_params=True` to load these `params` alongside the data during testing.

.. GENERATED FROM PYTHON SOURCE LINES 186-188

If you don't want to generate a dataset offline, you can also generate measurements online ("on the fly") during testing or training:

.. GENERATED FROM PYTHON SOURCE LINES 188-198

.. code-block:: Python

    dinv.test(
        model,
        DataLoader(dataset1),
        physics,
        plot_images=True,
        device=device,
        online_measurements=True,
    )
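Before looking at the result, it may help to see what generating measurements online means conceptually. The following is a minimal pure-PyTorch sketch, not DeepInverse's implementation: it assumes a hypothetical ground-truth-only dataset of random tensors (built with `torch.utils.data.TensorDataset`, not :class:`deepinv.datasets.TensorDataset`), and it replaces `physics` with a toy operator that applies a fixed random mask (`mask` and `forward_operator` are illustrative names). Measurements are simulated from each ground-truth batch as it is loaded, so none need to be stored on disk.

.. code-block:: Python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # Hypothetical ground-truth-only dataset: four random 3x8x8 "images"
    gt_dataset = TensorDataset(torch.rand(4, 3, 8, 8))

    # Toy stand-in for an inpainting operator: a fixed binary mask applied
    # pixel-wise (an assumption, not deepinv.physics.Inpainting itself)
    mask = (torch.rand(1, 3, 8, 8) > 0.5).float()


    def forward_operator(x):
        """Simulate measurements y = mask * x for a batch x."""
        return mask * x


    pairs = []
    for (x,) in DataLoader(gt_dataset, batch_size=2):
        # No stored measurements: y is simulated from the current batch
        y = forward_operator(x)
        pairs.append((x, y))  # a model would reconstruct x from y here

    print(len(pairs), tuple(pairs[0][1].shape))  # 2 (2, 3, 8, 8)

The trade-off is that measurements are recomputed on every pass over the data; with a random operator or random noise, each pass can therefore produce different measurements of the same image.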
.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_006.png
   :alt: Ground truth, Measurement, No learning, Reconstruction
   :srcset: /auto_examples/basics/images/sphx_glr_demo_custom_dataset_006.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/1 [00:00

` with your new dataset.

* Check out the :ref:`example on how to fine-tune a foundation model ` to your own data.
* Check out the :ref:`example on how to train a reconstruction model ` with your dataset.
* Advanced: how to :ref:`stream or download a dataset from HuggingFace `.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 14.687 seconds)

.. _sphx_glr_download_auto_examples_basics_demo_custom_dataset.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_custom_dataset.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_custom_dataset.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_custom_dataset.zip `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_