patchify

deepinv.models.utils.patchify(x, patch_size, stride=1)

Extract patches from a batch of images.

This function extracts patches of a given size from a batch of images, sliding the window by stride pixels. Patches overlap whenever stride < patch_size. The patches are returned in a format suitable for patch-based models.
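The sketch below shows one way to reproduce this extraction with torch.nn.functional.unfold and then reshape the result to (B, C, patch_size, patch_size, num_pch). It assumes a square integer patch_size and no padding. The function name patchify_sketch is hypothetical, and this is not necessarily how deepinv implements patchify. It does, however, yield the same output shape as the example further down this page.

```python
import torch
import torch.nn.functional as F


def patchify_sketch(x: torch.Tensor, patch_size: int, stride: int = 1) -> torch.Tensor:
    """Hypothetical unfold-based patch extraction (illustrative, not deepinv's code).

    x: tensor of shape (B, C, H, W).
    Returns a tensor of shape (B, C, patch_size, patch_size, num_pch).
    """
    B, C, _, _ = x.shape
    # unfold -> (B, C * patch_size * patch_size, num_pch);
    # patch positions are enumerated row by row (left to right, then top to bottom)
    cols = F.unfold(x, kernel_size=patch_size, stride=stride)
    # split the flattened dimension back into (C, patch_size, patch_size)
    return cols.reshape(B, C, patch_size, patch_size, -1)


x = torch.randn(1, 3, 256, 256)
patches = patchify_sketch(x, patch_size=8, stride=4)
print(patches.shape)  # torch.Size([1, 3, 8, 8, 3969])
```

With this layout, patches[..., k] is the k-th window in row-major order. For example, patches[..., 0] is the top-left 8x8 block of every channel.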

Parameters:
  • x (torch.Tensor) – input batch of images of shape (B, C, H, W)

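  • patch_size ((int, int)) – patch size (height and width of each extracted patch)

  • stride (int) – step, in pixels, between the top-left corners of consecutive patches (default: 1)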

Returns:

(torch.Tensor) patched image of shape (B, C, patch_size, patch_size, num_pch), where num_pch is the number of extracted patches

Return type:

Tensor
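The value of num_pch can be worked out in advance. The following is a minimal sketch that assumes valid sliding windows with no padding, which is consistent with the example below (256x256 input, patch_size=8, stride=4, giving 63 x 63 = 3969 patches). The helper num_patches is hypothetical and is not part of deepinv.

```python
def num_patches(height: int, width: int, patch_size: int, stride: int = 1) -> int:
    """Hypothetical helper: patch count for a no-padding sliding window."""
    n_h = (height - patch_size) // stride + 1  # window positions along the height
    n_w = (width - patch_size) // stride + 1   # window positions along the width
    return n_h * n_w


print(num_patches(256, 256, patch_size=8, stride=4))  # 3969 (63 * 63)
print(num_patches(256, 256, patch_size=8))            # 62001 (249 * 249)
```

With the default stride=1, the patch count grows roughly with H * W. The memory footprint of the output therefore grows by a factor of about patch_size**2 relative to the input, so a larger stride is often preferable for big images.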


Examples:

>>> import deepinv as dinv
>>> x = dinv.utils.load_example('butterfly.png')
>>> patches = dinv.models.utils.patchify(x, patch_size=8, stride=4)
>>> print(f"Input shape: {x.shape}, patchified shape: {patches.shape}")
Input shape: torch.Size([1, 3, 256, 256]), patchified shape: torch.Size([1, 3, 8, 8, 3969])