Bring your own dataset

This example shows how to use DeepInverse with your own dataset.

A dataset in DeepInverse can contain ground-truth images x, measurements y, physics parameters params, or any combination of these; each is optional.

See the datasets user guide for the formats in which data should be returned for compatibility with DeepInverse (e.g., for use with deepinv.Trainer).
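To make these combinations concrete, the hypothetical helper below classifies a single dataset item as x, (x, y) or (x, y, params). It is a simplified sketch written for this page, not deepinv's check_dataset(); the authoritative rules are in the user guide.

```python
import math

import torch


def describe_sample(sample):
    """Hypothetical helper: classify one dataset item by the formats described above.

    Returns "x", "(x, y)" or "(x, y, params)" and raises ValueError otherwise.
    This is a simplified illustration, not deepinv's check_dataset().
    """
    if isinstance(sample, torch.Tensor):
        return "x"  # ground truth only
    if isinstance(sample, (tuple, list)):
        if len(sample) == 2:
            return "(x, y)"  # x may be NaN when there is no ground truth
        if len(sample) == 3 and isinstance(sample[2], dict):
            return "(x, y, params)"  # params maps names to tensors, e.g. {"mask": ...}
    raise ValueError(f"Unsupported sample format: {type(sample).__name__}")


x = torch.rand(3, 8, 8)
y = x * (torch.rand(3, 8, 8) > 0.5)  # toy masked measurement
print(describe_sample(x))
print(describe_sample((math.nan, y)))
print(describe_sample((x, y, {"mask": (y != 0).float()})))
```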

DeepInverse provides multiple ways to bring your own dataset. This example has two parts: first, how to load images/data into a dataset, and second, how to use this dataset with DeepInverse.

import deepinv as dinv
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

Part 1: Loading data into a dataset

You have a folder of ground truth images

Here we imagine we have a folder with one ground truth image of a butterfly.

Tip

deepinv.datasets.ImageFolder can load any type of data (e.g., MRI or CT) if you pass in a custom loader function and transform.

DATA_DIR = dinv.utils.demo.get_data_home() / "demo_custom_dataset"
dinv.utils.download_example("butterfly.png", DATA_DIR / "GT")

dataset1 = dinv.datasets.ImageFolder(DATA_DIR / "GT", transform=ToTensor())

# Load one image from dataset
x = next(iter(DataLoader(dataset1)))

dinv.utils.plot({"x": x})
x
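As an illustration of the tip above, a custom loader for NumPy files could look like the function below. The function name, the .npy format and the (C, H, W) / (H, W) shape convention are our own assumptions; check the ImageFolder API reference for exactly how a loader function is passed in.

```python
import tempfile
from pathlib import Path

import numpy as np
import torch


def load_npy(path):
    """Hypothetical loader: read a .npy array from disk and return a float32 tensor.

    Assumption: each file stores one image-like array of shape (C, H, W) or (H, W).
    """
    array = np.load(path)
    tensor = torch.from_numpy(array).float()
    if tensor.ndim == 2:  # add a channel dimension for single-channel data (e.g. MRI/CT slices)
        tensor = tensor.unsqueeze(0)
    return tensor


# Write a small fake "scan" to a temporary folder and load it back.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "scan_000.npy"
    np.save(path, np.random.rand(32, 32))
    x = load_npy(path)

print(x.shape, x.dtype)
```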

You have a folder of paired ground truth and measurements

Now imagine we have a ground truth folder with a butterfly, and a measurements folder with a masked butterfly.

dinv.utils.download_example("butterfly_masked.png", DATA_DIR / "Measurements")

dataset2 = dinv.datasets.ImageFolder(
    DATA_DIR, x_path="GT/*.png", y_path="Measurements/*.png", transform=ToTensor()
)

x, y = next(iter(DataLoader(dataset2)))

dinv.utils.plot({"x": x, "y": y})
x, y
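The x_path and y_path arguments above are glob patterns relative to DATA_DIR. The sketch below pairs two such patterns by sorted filename order using only the standard library. The helper pair_files is hypothetical, and we have not verified that ImageFolder matches files in exactly this way; it only illustrates why the two folders should use filenames that sort consistently.

```python
import tempfile
from pathlib import Path


def pair_files(root, x_pattern, y_pattern):
    """Pair ground-truth and measurement files by sorted order (illustrative only)."""
    root = Path(root)
    x_files = sorted(root.glob(x_pattern))
    y_files = sorted(root.glob(y_pattern))
    if len(x_files) != len(y_files):
        raise ValueError(f"{len(x_files)} ground-truth files but {len(y_files)} measurements")
    return list(zip(x_files, y_files))


with tempfile.TemporaryDirectory() as tmp:
    for folder in ("GT", "Measurements"):
        (Path(tmp) / folder).mkdir()
        for name in ("img_001.png", "img_000.png"):
            (Path(tmp) / folder / name).touch()  # empty placeholder files
    pairs = [
        (x.relative_to(tmp).as_posix(), y.relative_to(tmp).as_posix())
        for x, y in pair_files(tmp, "GT/*.png", "Measurements/*.png")
    ]

print(pairs)
```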

Note

If you’re loading measurements whose params vary randomly, your dataset must return tuples (x, y, params) so that the physics is updated accordingly for every image. We provide a convenience argument, ImageFolder(estimate_params=...), to help you estimate these params on the fly.
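As a rough illustration of what returning (x, y, params) per image means, the sketch wraps any (x, y) dataset and estimates an inpainting mask from each measurement. The class name and the estimation rule (y != 0) are our own assumptions; this is not the callable signature expected by ImageFolder(estimate_params=...), which is documented in the API reference.

```python
import torch
from torch.utils.data import Dataset


class WithEstimatedMask(Dataset):
    """Hypothetical wrapper turning an (x, y) dataset into an (x, y, params) dataset.

    Assumption: measurements come from inpainting, so unobserved pixels are exactly 0
    and the mask can be estimated per sample as (y != 0).
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        params = {"mask": (y != 0).float()}  # estimated separately for every sample
        return x, y, params


# Toy (x, y) pairs, each masked with its own random mask.
torch.manual_seed(0)
masks, pairs = [], []
for _ in range(2):
    x = torch.rand(1, 4, 4) + 0.1  # strictly positive, so zeros in y come only from the mask
    mask = (torch.rand(1, 4, 4) > 0.5).float()
    masks.append(mask)
    pairs.append((x, x * mask))

dataset = WithEstimatedMask(pairs)  # a plain list works as the base dataset here
x0, y0, params0 = dataset[0]
print(torch.equal(params0["mask"], masks[0]))
```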

You have a folder of only measurements

Now imagine you have no ground truth, only measurements. In this case x should be returned as NaN, which ImageFolder does when no x_path is given:

dataset3 = dinv.datasets.ImageFolder(
    DATA_DIR, y_path="Measurements/*.png", transform=ToTensor()
)

x, y = next(iter(DataLoader(dataset3)))
print(x)
tensor([nan], dtype=torch.float64)
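The float64 dtype is consistent with how PyTorch's default collate function batches plain Python floats: a batch of float values becomes a single float64 tensor. The toy list of samples below (our own, not ImageFolder) uses math.nan as x to show this; we assume, without having checked the source, that ImageFolder returns a float NaN in the same way.

```python
import math

import torch
from torch.utils.data import DataLoader

# A plain list of (x, y) tuples is a valid map-style dataset for DataLoader.
# x is a Python float NaN because there is no ground truth.
samples = [(math.nan, torch.zeros(3, 4, 4)), (math.nan, torch.ones(3, 4, 4))]

x, y = next(iter(DataLoader(samples, batch_size=2)))
print(x)        # a batch of Python floats is collated into one float64 tensor
print(y.shape)  # tensors are stacked along a new leading batch dimension
```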

You already have tensors

If you already have tensors, you can construct a dataset with deepinv.datasets.TensorDataset. For example, here we build an unsupervised dataset containing a single measurement, which is loaded as the tuple (nan, y):

y = dinv.utils.load_example("butterfly_masked.png")

dataset4 = dinv.datasets.TensorDataset(y=y)

x, y = next(iter(DataLoader(dataset4)))
print(x)
tensor([nan], dtype=torch.float64)
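Conceptually, a measurement-only tensor dataset indexes samples along the first dimension of y and pairs each one with NaN. The class below is a hypothetical, minimal stand-in written with plain PyTorch to show that behaviour, not deepinv's implementation; it assumes y is stacked along dimension 0 (if your tensor has no leading batch dimension, add one with y.unsqueeze(0)). In practice, use deepinv.datasets.TensorDataset.

```python
import math

import torch
from torch.utils.data import Dataset


class MeasurementTensorDataset(Dataset):
    """Hypothetical stand-in for a measurement-only tensor dataset (not deepinv's class).

    Assumption: y holds N measurements stacked along dim 0, shape (N, C, H, W).
    """

    def __init__(self, y):
        self.y = y

    def __len__(self):
        return self.y.shape[0]  # one sample per entry along the first dimension

    def __getitem__(self, i):
        return math.nan, self.y[i]  # no ground truth, so x is NaN


y = torch.rand(1, 3, 16, 16)  # a single image with a leading batch dimension
dataset = MeasurementTensorDataset(y)
x0, y0 = dataset[0]
print(len(dataset), y0.shape)
```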

You already have a PyTorch dataset

Say you already have your own PyTorch dataset:

class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 1

    def __getitem__(self, i):  # Returns (x, y, params)
        return torch.zeros(1), torch.zeros(1), {"mask": torch.zeros(1)}


dataset5 = MyDataset()

You should check that your dataset is compatible using deepinv.datasets.check_dataset() (alternatively, inherit from deepinv.datasets.ImageDataset and call self.check_dataset()):

dinv.datasets.check_dataset(dataset5)

Part 2: Using your dataset with DeepInverse

Say you already have a DeepInverse problem set up, i.e. a physics operator (physics), a reconstruction model (model) and a device (device):

If your dataset already returns data in the form (x, y) or (x, y, params), you can test with it directly.

Our physics does not yet know the params (here, the inpainting mask). Since it is fixed across the dataset, we can define it manually by estimating it from y:

Note

If you’re loading measurements whose params vary randomly, your dataset must return tuples (x, y, params) so that the physics is updated accordingly for every image.

params = {"mask": (dataset2[0][1].to(device) != 0).float()}
physics.update(**params)
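The rule (y != 0) assumes that every observed pixel is nonzero. The toy example below, with made-up values, shows where that assumption breaks: a pixel that is genuinely black in the ground truth is indistinguishable from a masked one. Whether this matters depends on how many exact zeros your data contains.

```python
import torch

# Toy ground truth with one genuinely black pixel (value 0) at position (0, 0).
x = torch.tensor([[0.0, 0.5], [0.8, 0.3]])
true_mask = torch.tensor([[1.0, 1.0], [0.0, 1.0]])  # 1 = observed, 0 = masked
y = x * true_mask  # inpainting-style measurement

estimated_mask = (y != 0).float()  # same rule as in the code above
print(estimated_mask)  # pixel (0, 0) is wrongly marked as masked
```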

dinv.test(model, DataLoader(dataset2), physics, plot_images=True, device=device)
(Figure: Ground truth, Measurement, No learning, Reconstruction)
Test results:
PSNR no learning: 9.561 +- 0.000
PSNR: 31.915 +- 0.000

{'PSNR no learning': np.float64(9.560781478881836), 'PSNR no learning_std': 0, 'PSNR': np.float64(31.915376663208008), 'PSNR_std': 0}

This also works if the dataset doesn’t have ground truth:

Here, reference metrics such as PSNR return NaN because there is no ground truth, but no-reference metrics (here, NIQE) can still be used.
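To see why reference metrics become NaN, here is a plain-PyTorch PSNR (our own sketch, not deepinv's metric implementation) evaluated against a NaN-filled "ground truth": any NaN in the target makes the mean squared error NaN, and the NaN carries through the logarithm. No-reference metrics such as NIQE only look at the reconstruction, so they are unaffected.

```python
import math

import torch


def psnr(x_hat, x, max_pixel=1.0):
    """Reference-based PSNR in dB: 10 * log10(max_pixel^2 / MSE(x_hat, x))."""
    mse = torch.mean((x_hat - x) ** 2)
    return 10 * torch.log10(max_pixel**2 / mse)


x_hat = torch.rand(3, 8, 8)
x_missing = torch.full_like(x_hat, math.nan)  # no ground truth available
print(psnr(x_hat, x_missing))    # NaN propagates through the MSE
print(psnr(x_hat, x_hat + 0.1))  # MSE of about 0.01, i.e. roughly 20 dB
```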

(Figure: Measurement, No learning, Reconstruction)
Test results:
PSNR no learning: nan +- 0.000
PSNR: nan +- 0.000
NIQE no learning: 213.442 +- 0.000
NIQE: 5.720 +- 0.000

{'PSNR no learning': np.float64(nan), 'PSNR no learning_std': 0, 'PSNR': np.float64(nan), 'PSNR_std': 0, 'NIQE no learning': np.float64(213.4423065185547), 'NIQE no learning_std': 0, 'NIQE': np.float64(5.720289707183838), 'NIQE_std': 0}

Generating measurements

If your dataset returns only ground-truth x, you can generate a dataset of measurements using deepinv.datasets.generate_dataset():
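The idea behind offline generation is to apply the forward operator once to every ground-truth image and save the resulting pairs. The sketch below does this with plain PyTorch, a toy masking operator and torch.save; the operator, shapes and file name are our own assumptions, and this is not generate_dataset() itself, which writes an HDF5 (.h5) file.

```python
import tempfile
from pathlib import Path

import torch


def forward(x, mask):
    """Toy inpainting forward operator: keep observed pixels, zero out the rest."""
    return x * mask


# Hypothetical ground-truth images and a fixed mask shared by the whole dataset.
torch.manual_seed(0)
xs = torch.rand(4, 3, 8, 8)
mask = (torch.rand(1, 3, 8, 8) > 0.3).float()

# Offline: compute every measurement once and store it next to its ground truth.
ys = forward(xs, mask)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_dataset.pt"
    torch.save({"x": xs, "y": ys, "mask": mask}, path)
    stored = torch.load(path)

print(stored["y"].shape)
```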

(Figure: Ground truth, Measurement, No learning, Reconstruction)
Dataset has been saved at datasets/demo_custom_dataset/measurements/dinv_dataset0.h5

Test results:
PSNR no learning: 9.561 +- 0.000
PSNR: 31.915 +- 0.000

{'PSNR no learning': np.float64(9.560781478881836), 'PSNR no learning_std': 0, 'PSNR': np.float64(31.915376663208008), 'PSNR_std': 0}

Tip

Pass a physics generator to simulate random physics, then set load_physics_generator_params=True to load these params alongside the data during testing.

If you don’t want to generate a dataset offline, you can also generate measurements online (“on-the-fly”) during testing or training:

dinv.test(
    model,
    DataLoader(dataset1),
    physics,
    plot_images=True,
    device=device,
    online_measurements=True,
)
(Figure: Ground truth, Measurement, No learning, Reconstruction)
Test results:
PSNR no learning: 9.561 +- 0.000
PSNR: 31.915 +- 0.000

{'PSNR no learning': np.float64(9.560781478881836), 'PSNR no learning_std': 0, 'PSNR': np.float64(31.915376663208008), 'PSNR_std': 0}
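Conceptually, online measurement generation means that y is computed from each batch of x as it is loaded, rather than read from disk. The loop below sketches that idea with plain PyTorch and a toy noisy masking operator (both our own assumptions, not dinv.test internals). One consequence, under this sketch, is that random components of the physics, such as noise, are redrawn on every pass over the data.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical ground-truth-only dataset: a list of image tensors.
torch.manual_seed(0)
ground_truth = [torch.rand(3, 8, 8) for _ in range(6)]
mask = (torch.rand(3, 8, 8) > 0.5).float()  # toy inpainting mask, fixed


def forward(x, noise_std=0.05):
    """Toy noisy inpainting operator (our own assumption, not a deepinv physics)."""
    return x * mask + noise_std * torch.randn_like(x)


def simulate_one_pass():
    measurements = []
    for x in DataLoader(ground_truth, batch_size=2):  # no shuffling, so order is fixed
        y = forward(x)  # measurements are simulated on the fly for this batch
        # ... a model would reconstruct x from y here ...
        measurements.append(y)
    return torch.cat(measurements)


pass1, pass2 = simulate_one_pass(), simulate_one_pass()
print(pass1.shape, torch.allclose(pass1, pass2))
```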

🎉 Well done, you now know how to use your own dataset with DeepInverse!

What’s next?

Total running time of the script: (0 minutes 14.687 seconds)
