Source code for deepinv.datasets.kohler

from os.path import basename, join
from pathlib import Path
from typing import Any, Callable, Union
from urllib.parse import urlparse

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive


def url_basename(url: str) -> str:
    """Return the last component of the path part of ``url``.

    The URL is split with :func:`urllib.parse.urlparse` so that the query
    string and the fragment are not part of the result.

    :param str url: URL to extract the file name from.
    :return: (str) The base name of the URL path, possibly empty.
    """
    return basename(urlparse(url).path)


class Kohler(Dataset):
    """Dataset for `Recording and Playback of Camera Shake <https://doi.org/10.1007/978-3-642-33786-4_3>`_.

    The dataset consists of blurry shots and sharp frames, each blurry shot
    being associated with about 200 sharp frames. There are 48 blurry shots in
    total, each associated to one of 4 printouts, and to one of 12 camera
    trajectories inducing motion blur. Unlike certain deblurring datasets
    (e.g. GOPRO) where the blurry images are synthesized from sharp images,
    the blurry shots in the Köhler dataset are acquired with a real camera. It
    is the movement of the camera during exposure that causes the blur. What
    we call printouts are the 4 images that were printed out on paper and
    fixed to a screen to serve as photographed subjects — all images in the
    dataset show one of these 4 printouts. The ground truth images are **not**
    the 4 images that were printed out. Instead, they are the frames of videos
    taken in the same condition as for the blurry shots. The reason behind
    this choice is to ensure the same lightness for better comparison. In
    total, there are about 200 frames per video, and equivalently by blurry
    shot. There is a lot of redundancy between the frames as the camera barely
    moves between consecutive frames, for this reason the implementation
    allows selecting a single frame as the privileged ground truth. This
    enables using the tooling provided by deepinv such as
    :func:`deepinv.test` and which gives approximately the same performance as
    comparing to all the frames. It is the parameter ``frames`` that controls
    this behavior, when it is set to either ``"first"``, ``"middle"``,
    ``"last"``, or to a specific frame index (frame indices start at 1). If
    the user wants to compare against all the frames, e.g. to reproduce the
    benchmarks of the original paper, they can do so by setting the parameter
    ``frames`` to ``"all"`` or to a list of frame indices.

    The dataset does not have a preferred ordering and this implementation
    uses lexicographic ordering on the printout index (1 to 4) and the
    trajectory index (1 to 12). The parameter ``ordering`` controls whether to
    order by printout first ``"printout_first"`` or by trajectory first
    ``"trajectory_first"``. This enables accessing the 48 items using the
    standard method ``__getitem__`` using an index between 0 and 47. The
    nonstandard method ``get_item`` allows selecting one of them by printout
    and trajectory index directly if needed.

    :param Union[str, pathlib.Path] root: Root directory of the dataset.
    :param Union[int, str, list[Union[int, str]]] frames: Can be the frame
        number, ``"first"``, ``"middle"``, ``"last"``, or ``"all"``. If a list
        is provided, the method will return a list of sharp frames.
    :param str ordering: Ordering of the dataset. Can be ``"printout_first"``
        or ``"trajectory_first"``.
    :param Callable transform: (optional) A function used to transform both
        the blurry shots and the sharp frames.
    :param bool download: Download the dataset.

    |sep|

    :Examples:

        Download the dataset and load one of its elements ::

            from deepinv.datasets import Kohler
            dataset = Kohler(root="datasets/Kohler",
                             frames="middle",
                             ordering="printout_first",
                             download=True)
            # Usual interface
            sharp_frame, blurry_shot = dataset[0]
            print(sharp_frame.shape, blurry_shot.shape)
            # Convenience method to directly index the printouts and trajectories
            sharp_frame, blurry_shot = dataset.get_item(1, 1, frames="middle")
            print(sharp_frame.shape, blurry_shot.shape)

    """

    # The Köhler dataset is split into multiple archives available online.
    archive_urls = [
        "http://people.kyb.tuebingen.mpg.de/rolfk/BenchmarkECCV2012/GroundTruth_pngs_Image1.zip",
        "http://people.kyb.tuebingen.mpg.de/rolfk/BenchmarkECCV2012/GroundTruth_pngs_Image2.zip",
        "http://people.kyb.tuebingen.mpg.de/rolfk/BenchmarkECCV2012/GroundTruth_pngs_Image3.zip",
        "http://people.kyb.tuebingen.mpg.de/rolfk/BenchmarkECCV2012/GroundTruth_pngs_Image4.zip",
        "http://people.kyb.tuebingen.mpg.de/rolfk/BenchmarkECCV2012/BlurryImages.zip",
    ]

    # The checksums are used to verify the integrity of the downloaded
    # archives.
    archive_checksums = {
        "GroundTruth_pngs_Image1.zip": "acb90b6d9bfdb4b2370e08a5fcb80e68",
        "GroundTruth_pngs_Image2.zip": "da440d3bf43b32bec0b7170ccd828f29",
        "GroundTruth_pngs_Image3.zip": "3a77c41c951367f35db52eb18496bbac",
        "GroundTruth_pngs_Image4.zip": "72ce9690c3ed1296358653396cf9576d",
        "BlurryImages.zip": "61ffb1434d93fca6c508976a7216d723",
    }

    # Most of the acquisitions of sharp images span exactly 199 frames but not
    # all of them and this lookup table gives each frame count for them all.
    frame_count_table = {
        (2, 11): 200,
        (1, 10): 198,
        (1, 12): 198,
        (2, 10): 198,
        (3, 7): 198,
        (3, 12): 198,
        (4, 12): 198,
        "others": 199,
    }

    # Number of printouts and of camera trajectories; their product is the
    # number of blurry shots (48).
    _PRINTOUT_COUNT = 4
    _TRAJECTORY_COUNT = 12

    def __init__(
        self,
        root: Union[str, Path],
        frames: Union[int, str, list[Union[int, str]]] = "middle",
        ordering: str = "printout_first",
        transform: Union[Callable, None] = None,
        download: bool = False,
    ) -> None:
        self.root = root
        self.frames = frames
        self.ordering = ordering
        self.transform = transform

        if download:
            self.download(self.root)

    @classmethod
    def download(cls, root: Union[str, Path], remove_finished: bool = False) -> None:
        """Download the dataset.

        :param Union[str, pathlib.Path] root: Root directory of the dataset.
        :param bool remove_finished: Remove the archives after extraction.

        |sep|

        :Examples:

            Download the dataset ::

                from deepinv.datasets import Kohler
                Kohler.download("datasets/Kohler")

        """
        for url in cls.archive_urls:
            archive_name = url_basename(url)
            checksum = cls.archive_checksums[archive_name]

            # Download the archive and verify its integrity
            download_and_extract_archive(
                url,
                root,
                filename=archive_name,
                md5=checksum,
                remove_finished=remove_finished,
            )

    def __len__(self) -> int:
        return self._PRINTOUT_COUNT * self._TRAJECTORY_COUNT

    def __getitem__(
        self, index: int
    ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]:
        """Get a sharp frame and a blurry shot from the dataset.

        :param int index: Index of the pair, between ``-len(self)`` and
            ``len(self) - 1``. Negative indices count from the end.
        :return: (torch.Tensor, torch.Tensor) The sharp frame and the blurry
            shot. The first element is a list of sharp frames when the
            dataset was built with ``frames="all"`` or a list of frames.
        :raises ValueError: If ``self.ordering`` is not supported.
        :raises IndexError: If ``index`` is out of range.

        |sep|

        :Examples:

            Get the first sharp frame and blurry shot ::

                sharp_frame, blurry_shot = dataset[0]

        """
        printout_index, trajectory_index = self._split_index(index)
        return self.get_item(printout_index, trajectory_index, frames=self.frames)

    def _split_index(self, index: int) -> tuple[int, int]:
        """Convert a flat index into 1-based (printout, trajectory) indices."""
        if self.ordering not in ("printout_first", "trajectory_first"):
            raise ValueError(f"Unsupported ordering: {self.ordering}")

        size = len(self)
        # Raising IndexError is what lets plain ``for item in dataset`` loops
        # terminate for a map-style dataset.
        if not -size <= index < size:
            raise IndexError(
                f"Index {index} is out of range for a dataset of size {size}"
            )
        index %= size

        if self.ordering == "printout_first":
            # The trajectory varies fastest: 12 trajectories per printout.
            printout, trajectory = divmod(index, self._TRAJECTORY_COUNT)
        else:
            # The printout varies fastest: 4 printouts per trajectory.
            trajectory, printout = divmod(index, self._PRINTOUT_COUNT)
        return printout + 1, trajectory + 1

    # While users might sometimes want to thoroughly compare their own
    # deblurred images to all the sharp frames (about 200 per blurry shot),
    # they will probably most often make the way more convenient choice of
    # comparing against a single frame per blurry shot. For this reason, the
    # method get_item accepts an additional parameter for frame selection and
    # only returns the selected frame.
    def get_item(
        self,
        printout_index: int,
        trajectory_index: int,
        frames: Union[None, int, str, list[Union[int, str]]] = None,
    ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]:
        """Get a sharp frame and a blurry shot from the dataset.

        :param int printout_index: Index of the printout.
        :param int trajectory_index: Index of the trajectory.
        :param Union[None, int, str, list[Union[int, str]]] frames: Can be the
            frame number, ``"first"``, ``"middle"``, ``"last"``, or ``"all"``.
            If a list is provided, the method will return a list of sharp
            frames; each entry of the list can itself be a frame number,
            ``"first"``, ``"middle"`` or ``"last"``. By default, it uses the
            value provided in the constructor.
        :return: (Union[torch.Tensor, list[torch.Tensor]], torch.Tensor) The
            sharp frame(s) and the blurry shot.
        :raises ValueError: If a frame selection is not supported.

        |sep|

        :Examples:

            Get the first (middle) sharp frame and blurry shot ::

                sharp_frame, blurry_shot = dataset.get_item(1, 1, frames="middle")

            Get the list of all sharp frames and the blurry shot ::

                sharp_frames, blurry_shot = dataset.get_item(1, 1, frames="all")

            Query a list of specific frames and the blurry shot ::

                sharp_frames, blurry_shot = dataset.get_item(1, 1, frames=[1, "middle", 199])

        """
        blurry_shot = self.get_blurry_shot(printout_index, trajectory_index)

        if frames is None:
            frames = self.frames

        selection = self._resolve_frames(printout_index, trajectory_index, frames)
        if isinstance(selection, list):
            sharp_frames = [
                self.get_sharp_frame(printout_index, trajectory_index, frame_index)
                for frame_index in selection
            ]
            return sharp_frames, blurry_shot

        sharp_frame = self.get_sharp_frame(
            printout_index, trajectory_index, selection
        )
        return sharp_frame, blurry_shot

    @classmethod
    def _resolve_frames(
        cls,
        printout_index: int,
        trajectory_index: int,
        frames: Union[int, str, list[Union[int, str]]],
    ) -> Union[int, list[int]]:
        """Turn a frame selection into a frame index or a list of them."""
        if isinstance(frames, str) and frames == "all":
            frame_count = cls.get_frame_count(printout_index, trajectory_index)
            return list(range(1, frame_count + 1))

        if isinstance(frames, (list, tuple, range)):
            # Entries may be symbolic ("first", "middle", "last"), so each one
            # has to be resolved to an actual frame index.
            return [
                cls.select_frame(printout_index, trajectory_index, frame=frame)
                for frame in frames
            ]

        return cls.select_frame(printout_index, trajectory_index, frame=frames)

    def get_sharp_frame(
        self, printout_index: int, trajectory_index: int, frame_index: int
    ) -> Any:
        """Load one sharp frame from the root directory.

        :param int printout_index: Index of the printout.
        :param int trajectory_index: Index of the trajectory.
        :param int frame_index: Index of the frame.
        :return: The image opened with PIL, or the output of ``transform``
            applied to it when a transform was provided.
        """
        path = join(
            self.root,
            f"Image{printout_index}",
            f"Kernel{trajectory_index}",
            f"GroundTruth{printout_index}_{trajectory_index}_{frame_index}.png",
        )
        sharp_frame = Image.open(path)
        if self.transform is not None:
            sharp_frame = self.transform(sharp_frame)
        return sharp_frame

    def get_blurry_shot(self, printout_index: int, trajectory_index: int) -> Any:
        """Load one blurry shot from the root directory.

        :param int printout_index: Index of the printout.
        :param int trajectory_index: Index of the trajectory.
        :return: The image opened with PIL, or the output of ``transform``
            applied to it when a transform was provided.
        """
        path = join(self.root, f"Blurry{printout_index}_{trajectory_index}.png")
        blurry_shot = Image.open(path)
        if self.transform is not None:
            blurry_shot = self.transform(blurry_shot)
        return blurry_shot

    @classmethod
    def select_frame(
        cls, printout_index: int, trajectory_index: int, frame: Union[int, str]
    ) -> int:
        """Resolve a single frame selection to a frame index.

        :param int printout_index: Index of the printout.
        :param int trajectory_index: Index of the trajectory.
        :param Union[int, str] frame: A frame index, returned unchanged, or
            one of ``"first"``, ``"middle"`` and ``"last"``.
        :return: (int) The selected frame index.
        :raises ValueError: If ``frame`` is not a supported selection.
        """
        if isinstance(frame, int):
            return frame

        frame_count = cls.get_frame_count(printout_index, trajectory_index)
        if frame == "first":
            return 1
        if frame == "middle":
            return (frame_count + 1) // 2
        if frame == "last":
            return frame_count
        raise ValueError(f"Unsupported frame selection: {frame}")

    @classmethod
    def get_frame_count(cls, printout_index: int, trajectory_index: int) -> int:
        """Return the number of sharp frames listed for an acquisition.

        :param int printout_index: Index of the printout.
        :param int trajectory_index: Index of the trajectory.
        :return: (int) The entry of ``frame_count_table`` for this pair, or
            its ``"others"`` entry when the pair is not listed.
        """
        index = (printout_index, trajectory_index)
        return cls.frame_count_table.get(index, cls.frame_count_table["others"])