Datasets
This subpackage can be used for generating reconstruction datasets from other base datasets (e.g. MNIST or CelebA).
HDF5Dataset
DeepInverse HDF5 dataset with signal/measurement pairs. Generates a dataset of signal/measurement pairs from a base dataset.
Generating a dataset associated with a certain forward operator is done via deepinv.datasets.generate_dataset() using a base PyTorch dataset (torch.utils.data.Dataset, in this case MNIST). For example, here we generate a compressed sensing MNIST dataset:
Note
We support all data types supported by h5py, including complex numbers.
>>> import deepinv as dinv
>>> from torchvision import datasets, transforms
>>>
>>> save_dir = '../datasets/MNIST/' # directory where the dataset will be saved.
>>>
>>> # define base train dataset
>>> transform_data = transforms.Compose([transforms.ToTensor()])
>>> data_train = datasets.MNIST(root='../datasets/', train=True,
... transform=transform_data, download=True)
>>> data_test = datasets.MNIST(root='../datasets/', train=False, transform=transform_data)
>>>
>>> # define forward operator
>>> physics = dinv.physics.CompressedSensing(m=300, img_shape=(1, 28, 28))
>>> physics.noise_model = dinv.physics.GaussianNoise(sigma=.05)
>>>
>>> # generate paired dataset
>>> generated_dataset_path = dinv.datasets.generate_dataset(train_dataset=data_train, test_dataset=data_test,
... physics=physics, save_dir=save_dir, verbose=False)
Similarly, we can generate a dataset from a local folder of images (other types of data can be loaded using the loader and is_valid_file arguments of torchvision.datasets.ImageFolder):
>>> # Note that ImageFolder requires file structure to be '.../dir/train/xxx/yyy.ext' where xxx is an arbitrary class label
>>> data_train = datasets.ImageFolder(f'{save_dir}/train', transform=transform_data)
>>> data_test = datasets.ImageFolder(f'{save_dir}/test', transform=transform_data)
>>>
>>> dinv.datasets.generate_dataset(train_dataset=data_train, test_dataset=data_test,
...                                physics=physics, device=dinv.device, save_dir=save_dir)
The datasets are saved in .h5 (HDF5) format and can be easily loaded with PyTorch's standard torch.utils.data.DataLoader:
>>> from torch.utils.data import DataLoader
>>>
>>> dataset = dinv.datasets.HDF5Dataset(path=generated_dataset_path, train=True)
>>> dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
We can also use physics generators to randomly generate physics parameters for each sample, and save and load these parameters alongside the dataset:
>>> physics_generator = dinv.physics.generator.SigmaGenerator()
>>> pth = dinv.datasets.generate_dataset(train_dataset=data_train, test_dataset=data_test,
...                                      physics=physics, physics_generator=physics_generator,
...                                      device=dinv.device, save_dir=save_dir)
>>> dataset = dinv.datasets.HDF5Dataset(path=pth, load_physics_generator_params=True, train=True)
>>> dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
>>> x, y, params = next(iter(dataloader))
>>> print(params['sigma'].shape)
torch.Size([4])
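Conceptually, the per-sample parameters returned above let each element of a batch carry its own noise level. The following pure-Python sketch (independent of deepinv; all names here are hypothetical, not the library's API) illustrates the idea behind a sigma generator: draw one noise level per sample, then corrupt each signal with its own sigma.

```python
import random

def generate_sigmas(n, low=0.01, high=0.1, seed=0):
    # Sketch of a "sigma generator": one random noise level per sample,
    # analogous to the per-sample params['sigma'] loaded above.
    rng = random.Random(seed)
    return [rng.uniform(low, high) for _ in range(n)]

def add_noise(signal, sigma, seed=0):
    # Hypothetical helper: add Gaussian noise scaled by this sample's sigma.
    rng = random.Random(seed)
    return [v + rng.gauss(0.0, sigma) for v in signal]

batch = [[0.5, 0.2], [0.1, 0.9], [0.3, 0.3], [0.7, 0.4]]  # toy batch of 4 signals
sigmas = generate_sigmas(len(batch))
noisy = [add_noise(x, s, seed=i) for i, (x, s) in enumerate(zip(batch, sigmas))]
assert len(sigmas) == len(batch)  # one sigma per sample, matching the batch size
```

In the real library, this bookkeeping is handled for you: generate_dataset stores the sampled parameters in the HDF5 file, and HDF5Dataset returns them per batch when load_physics_generator_params=True.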
PatchDataset
Builds a dataset of all patches extracted from a tensor of images.
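To make the patch-extraction idea concrete, here is a minimal pure-Python sketch (not the deepinv implementation; the function name and defaults are illustrative) of enumerating all overlapping patches of an image, which is essentially what indexing such a dataset does.

```python
def extract_patches(image, patch_size, stride=1):
    # Sketch of patch enumeration: slide a patch_size x patch_size window
    # over the image with the given stride and collect every patch.
    # `image` is a list of rows (H x W); hypothetical helper, not deepinv's API.
    h, w = len(image), len(image[0])
    patches = []
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patches.append([row[j:j + patch_size]
                            for row in image[i:i + patch_size]])
    return patches

img = [[r * 4 + c for c in range(4)] for r in range(4)]  # toy 4x4 "image"
patches = extract_patches(img, patch_size=2)
assert len(patches) == (4 - 2 + 1) ** 2  # 9 overlapping 2x2 patches
```

With stride 1, an H x W image yields (H - p + 1) * (W - p + 1) patches of size p, so even small image collections produce large patch datasets.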
Image Datasets
Ready-made datasets available in the deepinv.datasets module:

Dataset for DIV2K Image Super-Resolution Challenge.
Dataset for Urban100.
Dataset for Set14.
Dataset for CBSD68.
Dataset for fastMRI that provides access to MR image slices.
Dataset for LIDC-IDRI that provides access to CT image slices.
Dataset for Flickr2K.
Dataset for LSDIR.
Dataset for Fluorescence Microscopy Denoising.