Source code for deepinv.datasets.lsdir

import hashlib
from typing import Any, Callable
import os

from PIL import Image
import torch

from deepinv.datasets.utils import (
    calculate_md5_for_folder,
    download_archive,
    extract_tarball,
)


class LsdirHR(torch.utils.data.Dataset):
    """Dataset for `LSDIR <https://data.vision.ee.ethz.ch/yawli/>`_.

    A large-scale dataset for image restoration tasks such as image super-resolution (SR),
    image denoising, JPEG deblocking, deblurring, demosaicking, and real-world SR.

    **Raw data file structure:** ::

        self.root --- 0001000 --- 0000001.png
                  |           |
                  |           -- 0001000.png
                  |   ...
                  -- 0085000 --- 0084001.png
                  |           |
                  |           -- 0084991.png
                  -- shard-00.tar.gz
                  |   ...
                  -- shard-16.tar.gz
                  |
                  -- val1 --- HR --- val --- 0000001.png
                  |       |              |
                  |       -- X2          -- 0000250.png
                  |       -- X3
                  |       -- X4
                  -- val1.tar.gz

    .. warning::
        The official site hosting the dataset is unavailable: https://data.vision.ee.ethz.ch/yawli/.
        Thus the download argument isn't working for now.

    :param str root: Root directory of the dataset. Directory path from where we load and save the dataset.
    :param str mode: Select a split of the dataset between 'train' or 'val'. Defaults to 'train'.
    :param bool download: If ``True``, downloads the dataset from the internet and puts it in the root directory.
        If the dataset is already downloaded, it is not downloaded again. Defaults to ``False``.
    :param callable, optional transform: A function/transform that takes in a PIL image
        and returns a transformed version. E.g., ``torchvision.transforms.RandomCrop``

    |sep|

    :Examples:

        Instantiate dataset and download raw data from the Internet ::

            from deepinv.datasets import LsdirHR
            val_dataset = LsdirHR(root="Lsdir", mode="val", download=True)  # download raw data at root and load dataset
            print(val_dataset.verify_split_dataset_integrity())  # check that raw data has been downloaded correctly
            print(len(val_dataset))  # check that we have 250 images

    """

    archive_urls = {
        "train": {
            "shard-00.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-00.tar.gz",
            "shard-01.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-01.tar.gz",
            "shard-02.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-02.tar.gz",
            "shard-03.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-03.tar.gz",
            "shard-04.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-04.tar.gz",
            "shard-05.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-05.tar.gz",
            "shard-06.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-06.tar.gz",
            "shard-07.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-07.tar.gz",
            "shard-08.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-08.tar.gz",
            "shard-09.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-09.tar.gz",
            "shard-10.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-10.tar.gz",
            "shard-11.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-11.tar.gz",
            "shard-12.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-12.tar.gz",
            "shard-13.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-13.tar.gz",
            "shard-14.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-14.tar.gz",
            "shard-15.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-15.tar.gz",
            "shard-16.tar.gz": "https://data.vision.ee.ethz.ch/yawli/shard-16.tar.gz",
        },
        "val": {
            "val1.tar.gz": "https://data.vision.ee.ethz.ch/yawli/val1.tar.gz",
        },
    }

    # for integrity of downloaded data
    checksums = {
        "train": "a83bdb97076d617e4965913195cc84d1",
        "val": "972ba478c530b76eb9404b038597f65f",
    }

    def __init__(
        self,
        root: str,
        mode: str = "train",
        download: bool = False,
        transform: Callable = None,
    ) -> None:
        self.root = root
        self.mode = mode
        self.transform = transform

        if self.mode == "train":
            # train_folder_names = ['0001000', ..., '0085000']
            train_folder_names = [str(i * 1000).zfill(7) for i in range(1, 86)]
            self.img_dirs = [
                os.path.join(self.root, folder) for folder in train_folder_names
            ]
        elif self.mode == "val":
            self.img_dirs = [os.path.join(self.root, "val1", "HR", "val")]
        else:
            raise ValueError(
                f"Expected `train` or `val` values for `mode` argument, instead got `{self.mode}`"
            )

        # download a split of the dataset, we check first that this split isn't already downloaded
        if download:
            raise ValueError(
                "The official site hosting the dataset is unavailable: https://data.vision.ee.ethz.ch/yawli/.\n"
                "Thus the download argument isn't working for now."
            )

            # NOTE: the code below is unreachable because of the raise above;
            # it is kept for when the official host becomes available again.
            if not os.path.isdir(self.root):
                os.makedirs(self.root)

            # if an image folder already exists, we stop the download
            if any([os.path.exists(img_dir) for img_dir in self.img_dirs]):
                raise ValueError(
                    f"The {self.mode} folders already exist, thus the download is aborted. "
                    f"Please set `download=False` OR remove `{self.img_dirs}`."
                )

            for filename, url in self.archive_urls[self.mode].items():
                # download tar file from the Internet and save it locally
                download_archive(
                    url=url,
                    save_path=os.path.join(self.root, filename),
                )
                # extract local tar file
                extract_tarball(os.path.join(self.root, filename), self.root)

            if self.verify_split_dataset_integrity():
                print("Dataset has been successfully downloaded.")
            else:
                raise ValueError("There is an issue with the data downloaded.")

        self.img_paths = []
        for img_dir in self.img_dirs:
            try:
                self.img_paths.extend(
                    [os.path.join(img_dir, fname) for fname in os.listdir(img_dir)]
                )
            except FileNotFoundError:
                raise RuntimeError(
                    "Data folder doesn't exist, please set `download=True`"
                )
        self.img_paths = sorted(self.img_paths)

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int) -> Any:
        img_path = self.img_paths[idx]
        # PIL Image
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)
        return img

    def verify_split_dataset_integrity(self) -> bool:
        """Verify the integrity and existence of the specified dataset split.

        The expected structure of the dataset directory is as follows: ::

            self.root --- 0001000 --- 0000001.png
                      |           |
                      |           -- 0001000.png
                      |   ...
                      -- 0085000 --- 0084001.png
                      |           |
                      |           -- 0084991.png
                      |
                      -- val1 --- HR --- val --- 0000001.png
                      |                      |
                      |                      -- 0000250.png
                      -- xxx
        """
        root_dir_exist = os.path.isdir(self.root)
        if not root_dir_exist:
            return False

        md5_folders = hashlib.md5()
        for img_dir in self.img_dirs:
            md5_folders.update(calculate_md5_for_folder(img_dir).encode())
        return md5_folders.hexdigest() == self.checksums[self.mode]
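

# A minimal usage sketch (not part of the module above). Because the official
# host is currently down and ``download=True`` raises, this assumes the
# validation split has already been extracted by hand so that images live
# under ``<root>/val1/HR/val``. The root name "Lsdir" and the crop size are
# illustrative choices, not values required by the dataset.
#
#     from torch.utils.data import DataLoader
#     from torchvision import transforms
#     from deepinv.datasets import LsdirHR
#
#     transform = transforms.Compose(
#         [transforms.CenterCrop(128), transforms.ToTensor()]
#     )
#     val_dataset = LsdirHR(root="Lsdir", mode="val", download=False, transform=transform)
#     print(len(val_dataset))            # 250 images if the split is complete
#     loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
#     batch = next(iter(loader))         # tensor of shape (4, 3, 128, 128)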