Source code for deepinv.datasets.div2k

from typing import Any, Callable
import os

from PIL import Image
import torch

from deepinv.datasets.utils import (
    calculate_md5_for_folder,
    download_archive,
    extract_zipfile,
)


class DIV2K(torch.utils.data.Dataset):
    """Dataset for `DIV2K Image Super-Resolution Challenge <https://data.vision.ee.ethz.ch/cvl/DIV2K>`_.

    Images have varying sizes with up to 2040 vertical pixels, and 2040 horizontal pixels.

    **Raw data file structure:** ::

        self.root --- DIV2K_train_HR --- 0001.png
                   |                  |
                   |                  -- 0800.png
                   |
                   -- DIV2K_valid_HR --- 0801.png
                   |                  |
                   |                  -- 0900.png
                   -- DIV2K_train_HR.zip
                   -- DIV2K_valid_HR.zip

    :param str root: Root directory of dataset. Directory path from where we load and save the dataset.
    :param str mode: Select a split of the dataset between 'train' or 'val'. Default at 'train'.
    :param bool download: If True, downloads the dataset from the internet and puts it in root directory.
        If dataset is already downloaded, it is not downloaded again. Default at False.
    :param Callable transform: (optional) A function/transform that takes in a PIL image
        and returns a transformed version. E.g., ``torchvision.transforms.RandomCrop``

    :raises ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``, if ``download=True`` but the
        split folder already exists with unexpected contents, or if the downloaded data fails the
        integrity check.
    :raises RuntimeError: if the split is not available locally and ``download=False``.

    |sep|

    :Examples:

        Instantiate dataset and download raw data from the Internet

        >>> import shutil
        >>> from deepinv.datasets import DIV2K
        >>> dataset = DIV2K(root="DIV2K", mode="val", download=True)  # download raw data at root and load dataset
        Dataset has been successfully downloaded.
        >>> print(dataset.verify_split_dataset_integrity())                # check that raw data has been downloaded correctly
        True
        >>> print(len(dataset))                                            # check that we have 100 images
        100
        >>> shutil.rmtree("DIV2K")                                         # remove raw data from disk

    """

    # https://data.vision.ee.ethz.ch/cvl/DIV2K/
    archive_urls = {
        "DIV2K_train_HR.zip": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
        "DIV2K_valid_HR.zip": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip",
    }

    # for integrity of downloaded data
    checksums = {
        "DIV2K_train_HR": "f9de9c251af455c1021017e61713a48b",
        "DIV2K_valid_HR": "542325e500b0a474c7ad18bae922da72",
    }

    def __init__(
        self,
        root: str,
        mode: str = "train",
        download: bool = False,
        transform: Callable = None,
    ) -> None:
        self.root = root
        self.mode = mode
        self.transform = transform

        if self.mode == "train":
            self.img_dir = os.path.join(self.root, "DIV2K_train_HR")
        elif self.mode == "val":
            self.img_dir = os.path.join(self.root, "DIV2K_valid_HR")
        else:
            raise ValueError(
                f"Expected `train` or `val` values for `mode` argument, instead got `{self.mode}`"
            )

        # download a split of the dataset, we check first that this split isn't already downloaded
        if not self.verify_split_dataset_integrity():
            if not download:
                # stop the execution since the split dataset is not available and we didn't download it
                raise RuntimeError(
                    f"Dataset not found at `{self.root}`. Please set `root` correctly (currently `root={self.root}`), OR set `download=True` (currently `download={download}`)."
                )
            self._download_split()

        self.img_list = self._list_images(self.img_dir)

    def _download_split(self) -> None:
        """Download and extract the archive of the selected split, then verify it."""
        os.makedirs(self.root, exist_ok=True)

        # Reaching this point means the integrity check failed; an existing folder
        # therefore has unexpected contents and we refuse to extract over it.
        if os.path.exists(self.img_dir):
            raise ValueError(
                f"The {self.mode} folder already exists, thus the download is aborted. Please set `download=False` OR remove `{self.img_dir}`."
            )

        zip_filename = (
            "DIV2K_train_HR.zip" if self.mode == "train" else "DIV2K_valid_HR.zip"
        )
        zip_path = os.path.join(self.root, zip_filename)

        # download zip file from the Internet and save it locally
        download_archive(url=self.archive_urls[zip_filename], save_path=zip_path)
        # extract local zip file
        extract_zipfile(zip_path, self.root)

        if not self.verify_split_dataset_integrity():
            raise ValueError("There is an issue with the data downloaded.")
        print("Dataset has been successfully downloaded.")

    @staticmethod
    def _list_images(img_dir: str) -> list:
        """Return the entries of ``img_dir`` sorted by name.

        ``os.listdir`` gives no ordering guarantee; sorting makes the mapping
        from index to image stable across machines and runs.
        """
        return sorted(os.listdir(img_dir))

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self.img_list)

    def __getitem__(self, idx: int) -> Any:
        """Load the image at position ``idx`` and apply ``transform`` if one was given.

        :param int idx: index into the sorted list of image files.
        :return: the PIL image, or whatever ``transform`` returns for it.
        """
        img_path = os.path.join(self.img_dir, self.img_list[idx])
        # PIL Image
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)
        return img

    def verify_split_dataset_integrity(self) -> bool:
        """Verify the integrity and existence of the specified dataset split.

        This method checks if ``DIV2K_train_HR`` or ``DIV2K_valid_HR`` folder within ``self.root`` exists
        and validates the integrity of its contents by comparing the MD5 checksum of the folder
        with the expected checksum.

        The expected structure of the dataset directory is as follows: ::

            self.root --- DIV2K_train_HR --- 0001.png
                       |                  |
                       |                  -- 0800.png
                       |
                       -- DIV2K_valid_HR --- 0801.png
                       |                  |
                       |                  -- 0900.png
                       -- xxx

        :return: ``True`` if the split folder exists and its checksum matches, ``False`` otherwise.
        """
        # Check the split folder itself (not only the root) so the outcome does not
        # depend on how the checksum helper treats a missing directory.
        if not (os.path.isdir(self.root) and os.path.isdir(self.img_dir)):
            return False

        folder_name = "DIV2K_train_HR" if self.mode == "train" else "DIV2K_valid_HR"
        return calculate_md5_for_folder(self.img_dir) == self.checksums[folder_name]