Kohler
- class deepinv.datasets.Kohler(root: str | Path, frames: int | str | list[int | str] = 'middle', ordering: str = 'printout_first', transform: Callable | None = None, download: bool = False)
Bases: Dataset
Dataset for Recording and Playback of Camera Shake.
The dataset consists of blurry shots and sharp frames, each blurry shot being associated with about 200 sharp frames. There are 48 blurry shots in total, each associated with one of 4 printouts and one of 12 camera trajectories inducing motion blur. Unlike certain deblurring datasets (e.g. GOPRO) where the blurry images are synthesized from sharp images, the blurry shots in the Köhler dataset are acquired with a real camera: the blur is caused by the movement of the camera during the exposure. What we call printouts are the 4 images that were printed out on paper and fixed to a screen to serve as photographed subjects; all images in the dataset show one of these 4 printouts.
The ground truth images are not the 4 images that were printed out. Instead, they are the frames of videos recorded under the same conditions as the blurry shots, which ensures the same lighting and hence a fairer comparison. Each video has about 200 frames, i.e. there are about 200 sharp frames per blurry shot. Consecutive frames are highly redundant because the camera barely moves between them, so the implementation allows selecting a single frame as the privileged ground truth. This makes it possible to use the tooling provided by deepinv, such as deepinv.test(), and yields approximately the same scores as comparing against all the frames. This behavior is controlled by the parameter frames, which can be set to "first", "middle", "last", or a specific frame index (between 1 and 198). To compare against all the frames instead, e.g. to reproduce the benchmarks of the original paper, set frames to "all" or to a list of frame indices.
The dataset has no preferred ordering; this implementation uses lexicographic ordering on the printout index (1 to 4) and the trajectory index (1 to 12). The parameter ordering controls whether items are ordered by printout first ("printout_first") or by trajectory first ("trajectory_first"). This makes the 48 items accessible through the standard method __getitem__ with an index between 0 and 47. The nonstandard method get_item selects an item directly by its printout and trajectory indices when needed.
- Parameters:
frames (Union[int, str, list[Union[int, str]]]) – Can be the frame number, "first", "middle", "last", or "all". If a list is provided, the dataset will return a list of sharp frames.
ordering (str) – Ordering of the dataset. Can be "printout_first" or "trajectory_first".
root (Union[str, Path]) – Root directory of the dataset.
transform (callable, optional) – A function used to transform both the blurry shots and the sharp frames.
download (bool) – If True, download the dataset.
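The lexicographic ordering described above can be made concrete with a small sketch of how a flat index in 0..47 could map to a (printout, trajectory) pair. It assumes plain row-major ordering with 1-based printout and trajectory indices, as used by get_item; the helpers index_to_pair and pair_to_index are hypothetical and not part of deepinv, and dataset.get_item remains the reliable way to fetch a specific pair.

```python
NUM_PRINTOUTS = 4
NUM_TRAJECTORIES = 12


def index_to_pair(index: int, ordering: str = "printout_first") -> tuple[int, int]:
    """Map a flat index in [0, 47] to a 1-based (printout, trajectory) pair.

    Hypothetical helper: assumes plain lexicographic (row-major) ordering.
    """
    if not 0 <= index < NUM_PRINTOUTS * NUM_TRAJECTORIES:
        raise IndexError(f"index must be in [0, 47], got {index}")
    if ordering == "printout_first":
        # The printout index varies slowest, the trajectory index fastest.
        printout, trajectory = divmod(index, NUM_TRAJECTORIES)
    elif ordering == "trajectory_first":
        # The trajectory index varies slowest, the printout index fastest.
        trajectory, printout = divmod(index, NUM_PRINTOUTS)
    else:
        raise ValueError(f"unknown ordering: {ordering!r}")
    return printout + 1, trajectory + 1


def pair_to_index(printout: int, trajectory: int, ordering: str = "printout_first") -> int:
    """Inverse of index_to_pair, under the same assumptions."""
    if ordering == "printout_first":
        return (printout - 1) * NUM_TRAJECTORIES + (trajectory - 1)
    if ordering == "trajectory_first":
        return (trajectory - 1) * NUM_PRINTOUTS + (printout - 1)
    raise ValueError(f"unknown ordering: {ordering!r}")


print(index_to_pair(13))                      # (2, 2)
print(index_to_pair(13, "trajectory_first"))  # (2, 4)
```

The same flat index therefore refers to different shots depending on ordering, which is why results indexed by position should record which ordering was used.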
- Examples:
Download the dataset and load one of its elements
from deepinv.datasets import Kohler
dataset = Kohler(root="datasets/Kohler", frames="middle", ordering="printout_first", download=True)
# Usual interface
sharp_frame, blurry_shot = dataset[0]
print(sharp_frame.shape, blurry_shot.shape)
# Convenience method to directly index the printouts and trajectories
sharp_frame, blurry_shot = dataset.get_item(1, 1, frames="middle")
print(sharp_frame.shape, blurry_shot.shape)
- classmethod download(root: str | Path, remove_finished: bool = False) → None
Download the dataset.
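- Parameters:
root (Union[str, Path]) – Root directory of the dataset.
remove_finished (bool) – If True, remove the downloaded archives after extraction.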
- Examples:
Download the dataset
from deepinv.datasets import Kohler
Kohler.download("datasets/Kohler")
- get_item(printout_index: int, trajectory_index: int, frames: None | int | str | list[int | str] = None) → tuple[Tensor, Tensor]
Get a sharp frame and a blurry shot from the dataset.
- Parameters:
printout_index (int) – Index of the printout, from 1 to 4.
trajectory_index (int) – Index of the trajectory, from 1 to 12.
frames (Union[None, int, str, list[Union[int, str]]]) – Can be the frame number, “first”, “middle”, “last”, or “all”. If a list is provided, the method will return a list of sharp frames. By default, it uses the value provided in the constructor.
- Returns:
(Union[torch.Tensor, list[torch.Tensor]], torch.Tensor) The sharp frame(s) and the blurry shot.
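How a frames specification could be turned into concrete frame indices is sketched below, which also shows why a list yields a list of sharp frames while a single selector yields one. This is a hypothetical helper, not deepinv's code: it assumes frames are numbered from 1 to num_frames, that "middle" means the central frame (rounded down for an even count), and that "all" expands to every frame. The actual implementation may differ, e.g. in how it rounds the middle frame; the example uses a toy count of 9 frames.

```python
from typing import Union

FrameSpec = Union[int, str]


def resolve_frames(frames: Union[FrameSpec, list[FrameSpec]], num_frames: int) -> Union[int, list[int]]:
    """Turn a frame specification into 1-based frame index(es). Hypothetical helper."""
    if isinstance(frames, list):
        # A list yields a list of indices (and hence a list of sharp frames).
        return [resolve_frames(f, num_frames) for f in frames]
    if frames == "all":
        return list(range(1, num_frames + 1))
    if frames == "first":
        return 1
    if frames == "middle":
        # Assumed convention: central frame, rounded down for an even count.
        return (num_frames + 1) // 2
    if frames == "last":
        return num_frames
    if isinstance(frames, int) and 1 <= frames <= num_frames:
        return frames
    raise ValueError(f"invalid frame specification: {frames!r}")


print(resolve_frames("middle", 9))                # 5
print(resolve_frames([1, "middle", "last"], 9))  # [1, 5, 9]
```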
- Examples:
Get the middle sharp frame and the blurry shot
sharp_frame, blurry_shot = dataset.get_item(1, 1, frames="middle")
Get the list of all sharp frames and the blurry shot
sharp_frames, blurry_shot = dataset.get_item(1, 1, frames="all")
Query a list of specific frames and the blurry shot
sharp_frames, blurry_shot = dataset.get_item(1, 1, frames=[1, "middle", 199])
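When frames="all" returns a list of sharp frames, a restored image can be scored against each of them. The sketch below computes PSNR against every frame and reports both the best and the mean score. Which aggregate to report, and whether images must be aligned first as some benchmark protocols require, is left to the user; the helpers psnr and psnr_against_frames are hypothetical, restore is a placeholder for a deblurring method, and images are assumed to be tensors with values in [0, 1].

```python
import torch


def psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images with values in [0, max_val]."""
    mse = torch.mean((x - y) ** 2)
    return (10 * torch.log10(max_val**2 / mse)).item()


def psnr_against_frames(estimate: torch.Tensor, sharp_frames: list[torch.Tensor]) -> tuple[float, float]:
    """Score an estimate against every sharp frame; return (best, mean) PSNR."""
    scores = [psnr(estimate, frame) for frame in sharp_frames]
    return max(scores), sum(scores) / len(scores)


# Stand-in data: a few near-identical "frames" and a noisy estimate of the first one.
torch.manual_seed(0)
base = torch.rand(3, 64, 64)
sharp_frames = [torch.clamp(base + 0.01 * torch.randn_like(base), 0, 1) for _ in range(5)]
estimate = torch.clamp(sharp_frames[0] + 0.05 * torch.randn_like(base), 0, 1)

best, mean = psnr_against_frames(estimate, sharp_frames)
print(f"best PSNR: {best:.2f} dB, mean PSNR: {mean:.2f} dB")

# With the dataset (requires the downloaded data; `restore` is a placeholder):
# sharp_frames, blurry_shot = dataset.get_item(1, 1, frames="all")
# best, mean = psnr_against_frames(restore(blurry_shot), sharp_frames)
```

Because consecutive frames are nearly identical, the best and mean scores are usually close, which is consistent with the observation that a single privileged frame gives approximately the same performance.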