PatchDataset

class deepinv.datasets.PatchDataset(imgs, patch_size=6, stride=1, transform=None, shape=(-1,))

Bases: TiledMixin2d, ImageDataset

Builds a dataset whose items are all the patches extracted from a tensor of images.
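Assuming that only patches lying entirely inside each image are kept (no padding at the borders; this is an assumption, not something stated in this reference), the dataset length for B images of size H × W, with patch size (p_h, p_w) and stride (s_h, s_w), would be:

```latex
N = B \left( \left\lfloor \frac{H - p_h}{s_h} \right\rfloor + 1 \right)
      \left( \left\lfloor \frac{W - p_w}{s_w} \right\rfloor + 1 \right)
```

For instance, B = 2, H = W = 32, patch_size=6 and stride=2 would give 2 · 14 · 14 = 392 patches.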

Parameters:
  • imgs (torch.Tensor) – tensor of images of shape (B, C, H, W).

  • patch_size (int | tuple[int, int]) – size of patches to extract. If int, the same value is used for height and width.

  • stride (int | tuple[int, int]) – stride between patches. If int, the same value is used for height and width.

  • transform (Callable) – callable used for data augmentation. Set to None for no augmentation.

  • shape (tuple) – shape of each returned patch. The default (-1,) flattens the patch. If None, the patch is returned with shape (C, h, w), where h and w are the patch height and width.
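The parameters above can be illustrated with a minimal pure-PyTorch sketch. It is not deepinv's implementation: the class name SimplePatchDataset is hypothetical, and the sketch assumes that only full patches are kept (no border padding) and that transform is applied to each (C, h, w) patch before the final reshape.

```python
import torch
from torch.utils.data import Dataset


class SimplePatchDataset(Dataset):
    """Hypothetical sketch: every patch_size window of every image, taken at the given stride."""

    def __init__(self, imgs, patch_size=6, stride=1, transform=None, shape=(-1,)):
        # An int means the same value for height and width.
        ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        sh, sw = (stride, stride) if isinstance(stride, int) else stride
        # Tensor.unfold(dim, size, step) slides a window along one dimension:
        # (B, C, H, W) -> (B, C, nh, W, ph) -> (B, C, nh, nw, ph, pw)
        patches = imgs.unfold(2, ph, sh).unfold(3, pw, sw)
        # Put the patch-grid axes before the channel axis, then merge the image
        # index and the patch position into a single dataset index.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        self.patches = patches.reshape(-1, imgs.shape[1], ph, pw)
        self.transform = transform
        self.shape = shape

    def __len__(self):
        return self.patches.shape[0]

    def __getitem__(self, idx):
        patch = self.patches[idx]  # (C, ph, pw)
        if self.transform is not None:
            patch = self.transform(patch)
        # shape=None keeps (C, h, w); the default (-1,) flattens the patch.
        return patch if self.shape is None else patch.reshape(self.shape)


imgs = torch.rand(2, 3, 32, 32)
dataset = SimplePatchDataset(imgs, patch_size=6, stride=2)
print(len(dataset))      # 2 images x 14 x 14 positions = 392
print(dataset[0].shape)  # 3 * 6 * 6 = 108 values after flattening
```

Precomputing all patches keeps indexing simple, but with stride=1 on large images it copies far more data than the images themselves; computing each patch lazily from its index in __getitem__ avoids that cost.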

Examples using PatchDataset:

Patch priors for limited-angle computed tomography