Source code for deepinv.datasets.lsdir

import hashlib
from typing import Callable
from types import MappingProxyType
import os

from deepinv.datasets.utils import (
    calculate_md5_for_folder,
    extract_tarball,
)
from deepinv.datasets.base import ImageFolder
import shutil


[docs] class LsdirHR(ImageFolder): """Dataset for `LSDIR <https://ofsoundof.github.io/lsdir-data/>`_. Published in :footcite:t:`li2023lsdir`. A large-scale dataset for image restoration tasks such as image super-resolution (SR), image denoising, JPEG deblocking, deblurring, and demosaicking, and real-world SR. **Raw data file structure:** :: self.root --- 0001000 --- 0000001.png | | | -- 0001000.png | ... | -- 0085000 --- 0084001.png | | | -- 0084991.png | | -- val1 --- HR --- val --- 0000001.png | -- X2 | | -- X3 -- 0000250.png | -- X4 .. warning:: Downloading this dataset requires ``huggingface-hub``. It is gated, please request access (https://huggingface.co/ofsoundof/LSDIR) and make sure you are logged in using ``hf auth login`` (CLI) or ``from huggingface_hub import login, login()``. :param str root: Root directory of dataset. Directory path from where we load and save the dataset. :param str mode: Select a split of the dataset between 'train' or 'val'. Default at 'train'. :param bool download: If ``True``, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Default at False. :param Callable transform:: (optional) A function/transform that takes in a PIL image and returns a transformed version. E.g, ``torchvision.transforms.RandomCrop`` |sep| :Examples: Instantiate dataset and download raw data from the Internet :: from deepinv.datasets import LsdirHR val_dataset = LsdirHR(root="Lsdir", mode="val", download=True) # download raw data at root and load dataset print(val_dataset.verify_split_dataset_integrity()) # check that raw data has been downloaded correctly print(len(val_dataset)) # check that we have 250 images """ _archive_urls = MappingProxyType( { "train": [ "shard-00.tar.gz", "shard-01.tar.gz", "shard-02.tar.gz", "shard-03.tar.gz", "shard-04.tar.gz", "shard-05.tar.gz", "shard-06.tar.gz", "shard-07.tar.gz", "shard-08.tar.gz", "shard-09.tar.gz", "shard-10.tar.gz", "shard-11.tar.gz", "shard-12.tar.gz", "shard-13.tar.gz", "shard-14.tar.gz", "shard-15.tar.gz", "shard-16.tar.gz", ], "val": [ "val1.tar.gz", ], } ) # for integrity of downloaded data _checksums = MappingProxyType( { "train": "a83bdb97076d617e4965913195cc84d1", "val": "972ba478c530b76eb9404b038597f65f", } ) def __init__( self, root: str, mode: str = "train", download: bool = False, transform: Callable = None, ) -> None: self.root = root self.mode = mode if self.mode == "train": # train_folder_names = ['0001000', ..., '0085000'] train_folder_names = [str(i * 1000).zfill(7) for i in range(1, 86)] self.img_dirs = [ os.path.join(self.root, folder) for folder in train_folder_names ] elif self.mode == "val": self.img_dirs = [os.path.join(self.root, "val1", "HR", "val")] else: raise ValueError( f"Expected `train` or `val` values for `mode` argument, instead got `{self.mode}`" ) # download a split of the dataset, we check first that this split isn't already downloaded if download: # pragma: no cover try: from huggingface_hub import hf_hub_download except: raise RuntimeError( "To download LsdirHR, please install huggingface-hub, request access to https://huggingface.co/ofsoundof/LSDIR (this should be instant), and authenticate yourself. For more info, see the LsdirHR's documentation." ) if not os.path.isdir(self.root): os.makedirs(self.root) # if a folder image exists, we stop the download if any([os.path.exists(img_dir) for img_dir in self.img_dirs]): raise ValueError( f"The {self.mode} folders already exists, thus the download is aborted. Please set `download=False` OR remove `{self.img_dirs}`." ) for filename in self._archive_urls[self.mode]: # download tar file from HuggingFace and save it locally hf_hub_download( repo_id="ofsoundof/LSDIR", filename=filename, local_dir=os.path.join(self.root), cache_dir=os.path.join(self.root, ".cache/huggingface"), local_dir_use_symlinks=False, ) # Since LSDIR is relatively large, we want to avoid taking too much redundant disk space. shutil.rmtree(os.path.join(self.root, ".cache/huggingface")) # extract local tar file extract_tarball(os.path.join(self.root, filename), self.root) os.remove( os.path.join(self.root, filename) ) # Since LSDIR is relatively large, we want to avoid taking too much redundant disk space. if self.verify_split_dataset_integrity(): print("Dataset has been successfully downloaded.") else: raise ValueError("There is an issue with the data downloaded.") if not all( os.path.isdir(d) and os.listdir(d) for d in self.img_dirs ): # pragma: no cover raise RuntimeError("Data folder doesn't exist, please set `download=True`") # Initialize ImageFolder if mode == "val": super().__init__(self.root, x_path="val1/HR/val/*.png", transform=transform) else: # mode is train for sure, because of earlier check super().__init__( self.root, x_path="shard-[0-1][0-9]/**/*.png", transform=transform )
[docs] def verify_split_dataset_integrity(self) -> bool: """Verify the integrity and existence of the specified dataset split. The expected structure of the dataset directory is as follows: :: self.root --- 0001000 --- 0000001.png | | | -- 0001000.png | ... | -- 0085000 --- 0084001.png | | | -- 0084991.png | | -- val1 --- HR --- val --- 0000001.png | -- X2 | | -- X3 -- 0000250.png | """ root_dir_exist = os.path.isdir(self.root) if not root_dir_exist: # pragma: no cover return False md5_folders = hashlib.md5() for img_dir in self.img_dirs: md5_folders.update(calculate_md5_for_folder(img_dir).encode()) return md5_folders.hexdigest() == self._checksums[self.mode]