patchify
- deepinv.models.utils.patchify(x, patch_size, stride=1)
Patchify images.
This function takes a batch of images and extracts square patches of size patch_size × patch_size with the given stride, returning them in a format suitable for patch-based models. Consecutive patches overlap whenever stride < patch_size.
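The following is a minimal sketch of the extraction step. It assumes the operation can be expressed with torch.nn.functional.unfold and that patches are ordered row by row over the image. The name patchify_sketch is hypothetical, and this is not necessarily how deepinv implements patchify. It only reproduces the documented output layout (B, C, patch_size, patch_size, num_pch).

```python
import torch
import torch.nn.functional as F


def patchify_sketch(x: torch.Tensor, patch_size: int, stride: int = 1) -> torch.Tensor:
    """Hypothetical patch extractor (not deepinv's code).

    x: images of shape (B, C, H, W).
    Returns a tensor of shape (B, C, patch_size, patch_size, num_pch).
    """
    B, C, _, _ = x.shape
    # unfold (im2col) gives (B, C * patch_size * patch_size, num_pch);
    # the middle axis is ordered channel first, then patch row, then patch column,
    # and the last axis enumerates patch positions row by row (no padding)
    cols = F.unfold(x, kernel_size=patch_size, stride=stride)
    # split the flattened middle axis back into (C, patch_size, patch_size)
    return cols.reshape(B, C, patch_size, patch_size, -1)


# Same sizes as the documented example: 256x256 RGB image, 8x8 patches, stride 4
print(patchify_sketch(torch.randn(1, 3, 256, 256), patch_size=8, stride=4).shape)
# torch.Size([1, 3, 8, 8, 3969])
```

Using unfold keeps the extraction in a single vectorized call, and the final reshape only relabels the flattened axis without copying data.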
- Parameters:
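x (torch.Tensor) – input batch of images of shape (B, C, H, W)
patch_size (int) – side length of the square patches
stride (int) – stride between consecutive patches (default: 1)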
- Returns:
patched image of shape (B, C, patch_size, patch_size, num_pch), where num_pch is the number of extracted patches
- Return type:
torch.Tensor
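For an H × W image, num_pch can be computed ahead of time. The sketch below assumes that patches are only taken at positions where they fit entirely inside the image, with no padding. This assumption is consistent with the 3969 patches in the example below, but it is not stated explicitly in the documentation. The helper num_patches is hypothetical and is not part of deepinv.

```python
def num_patches(height: int, width: int, patch_size: int, stride: int = 1) -> int:
    """Hypothetical helper: number of patches extracted without padding."""
    # valid patch start positions along each axis
    n_h = (height - patch_size) // stride + 1
    n_w = (width - patch_size) // stride + 1
    return n_h * n_w


# 256x256 image, 8x8 patches, stride 4: 63 positions per axis
print(num_patches(256, 256, patch_size=8, stride=4))  # 3969
```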
- Examples:
>>> import deepinv as dinv
>>> x = dinv.utils.load_example('butterfly.png')
>>> patches = dinv.models.utils.patchify(x, patch_size=8, stride=4)
>>> print(f"Input shape: {x.shape}, patchified shape: {patches.shape}")
Input shape: torch.Size([1, 3, 256, 256]), patchified shape: torch.Size([1, 3, 8, 8, 3969])