TiledMixin2d
- class deepinv.utils.TiledMixin2d(patch_size, stride=None, pad_if_needed=True, *args, **kwargs)
Bases: object
Mixin base class for 2D tiled patch extraction and reconstruction. Provides methods to extract overlapping patches from images and to reconstruct images from patches.
It also pads the image when necessary so that all patches have the same size. Patch extraction and reconstruction are implemented with PyTorch's unfold and fold operations for efficiency.
- Parameters:
patch_size (int | tuple[int, int]) – Size of each patch as (height, width), or a single int for square patches.
stride (int | tuple[int, int]) – Stride between adjacent patches as (height, width). If a single int is provided, it is used for both dimensions. Defaults to half the patch size.
pad_if_needed (bool) – If True, the image is padded if necessary to ensure all patches have the same size. Defaults to True.
The following example demonstrates how to use TiledMixin2d to extract patches from an image and reconstruct the image from those patches.
- Examples:
>>> import torch
>>> from deepinv.utils.mixins import TiledMixin2d
>>> # Create an image of shape (B, C, H, W)
>>> B, C, H, W = 1, 1, 5, 5
>>> image = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)
>>> print(image)
tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])

>>> # Initialize the TiledMixin2d with patch size and stride
>>> patch_size = (3, 3)
>>> stride = (2, 2)
>>> tiled_mixin = TiledMixin2d(patch_size=patch_size, stride=stride)

>>> # Extract patches from the image
>>> patches = tiled_mixin.image_to_patches(image)
>>> print("Extracted Patches Shape:", patches.shape)
Extracted Patches Shape: torch.Size([1, 1, 2, 2, 3, 3])
>>> print(patches[..., 0, 0, :, :])  # Print the first patch for verification
tensor([[[[ 0.,  1.,  2.],
          [ 5.,  6.,  7.],
          [10., 11., 12.]]]])

>>> # Reconstruct the image from the patches
>>> reconstructed_image = tiled_mixin.patches_to_image(patches, img_size=(H, W))
>>> print("Reconstructed Image Shape:", reconstructed_image.shape)
Reconstructed Image Shape: torch.Size([1, 1, 5, 5])
>>> print(reconstructed_image)
tensor([[[[ 0.,  1.,  4.,  3.,  4.],
          [ 5.,  6., 14.,  8.,  9.],
          [20., 22., 48., 26., 28.],
          [15., 16., 34., 18., 19.],
          [20., 21., 44., 23., 24.]]]])
>>> # Note that by default, the reconstructed image is not necessarily equal
>>> # to the original image, because overlapping regions are summed.
>>> # Setting `reduce_overlap="mean"` in `patches_to_image` averages the
>>> # overlapping regions instead of summing them, which can give a closer
>>> # reconstruction of the original image.
>>> reconstructed_image_mean = tiled_mixin.patches_to_image(patches, img_size=(H, W), reduce_overlap="mean")
>>> print(reconstructed_image_mean)
tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])
- get_num_patches(img_size)
Get the number of patches along height and width.
If pad_if_needed is True, this returns the number of patches that can be extracted after padding the image to a compatible size.
If pad_if_needed is False, this returns the number of patches that can be extracted without padding, which may not cover the whole image.
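The difference between the two cases can be illustrated with a little arithmetic. The sketch below is not the library's implementation: `num_patches_1d` and `num_patches_2d` are hypothetical helpers, and they assume that padding extends the image just far enough for the last patch to fit (rounding up), while the unpadded case drops any incomplete trailing patch (rounding down).

```python
import math


def num_patches_1d(size, patch, stride, pad_if_needed=True):
    """Hypothetical helper: count patches along a single axis.

    Assumption (not taken from the library source): with padding the count
    is rounded up, without padding it is rounded down.
    """
    if size <= patch:
        # One patch at most; without padding it only fits if sizes match.
        return 1 if (pad_if_needed or size == patch) else 0
    steps = (size - patch) / stride
    rounded = math.ceil(steps) if pad_if_needed else math.floor(steps)
    return rounded + 1


def num_patches_2d(img_size, patch_size, stride, pad_if_needed=True):
    """Hypothetical helper: (n_rows, n_cols) for an (H, W) image."""
    return tuple(
        num_patches_1d(s, p, st, pad_if_needed)
        for s, p, st in zip(img_size, patch_size, stride)
    )


# The 5x5 image from the class example: 3x3 patches with stride 2.
print(num_patches_2d((5, 5), (3, 3), (2, 2)))          # (2, 2)
# A 6x6 image needs padding to be fully covered with the same settings.
print(num_patches_2d((6, 6), (3, 3), (2, 2), True))    # (3, 3)
print(num_patches_2d((6, 6), (3, 3), (2, 2), False))   # (2, 2)
```

In the 6x6 case without padding, the last row and column of the image are not covered by any patch, which is the situation the description above warns about.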
- image_to_patches(image)
Split an image into overlapping patches.
The image will be padded if necessary to ensure all patches have the same size.
- Parameters:
image (torch.Tensor) – Input image tensor of shape (B, C, H, W).
- Returns:
Patches tensor of shape (B, C, n_rows, n_cols, patch_h, patch_w).
- Return type:
torch.Tensor
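To make the output layout concrete, here is a minimal pure-Python sketch of the index mapping for a single (H, W) channel without padding. It is only an illustration, not the library's code (which, per the class description, relies on PyTorch's unfold); `extract_patches` is a hypothetical name.

```python
def extract_patches(image, patch_size, stride):
    """Hypothetical reference for one (H, W) channel, no padding.

    Returns a nested list indexed as patches[i][j][p][q], mirroring the
    trailing (n_rows, n_cols, patch_h, patch_w) axes of the documented shape.
    """
    ph, pw = patch_size
    sh, sw = stride
    height, width = len(image), len(image[0])
    n_rows = (height - ph) // sh + 1
    n_cols = (width - pw) // sw + 1
    return [
        [
            # Patch (i, j) starts at pixel (i * sh, j * sw).
            [[image[i * sh + p][j * sw + q] for q in range(pw)] for p in range(ph)]
            for j in range(n_cols)
        ]
        for i in range(n_rows)
    ]


# Same 5x5 ramp image as in the class example.
image = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
patches = extract_patches(image, (3, 3), (2, 2))

print(len(patches), len(patches[0]))  # 2 2
print(patches[0][0])  # [[0.0, 1.0, 2.0], [5.0, 6.0, 7.0], [10.0, 11.0, 12.0]]
print(patches[1][1])  # [[12.0, 13.0, 14.0], [17.0, 18.0, 19.0], [22.0, 23.0, 24.0]]
```

The first patch matches the one printed in the class example; pixel (2, 2), with value 12, appears in all four patches.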
- patches_to_image(patches, img_size=None, reduce_overlap='sum')
Reconstruct an image from overlapping patches.
This is the inverse operation of image_to_patches. With the default reduce_overlap="sum", overlapping regions are summed, so the reconstructed image is not necessarily equal to the original image.
- Parameters:
patches (torch.Tensor) – Patches tensor of shape (B, C, n_rows, n_cols, patch_h, patch_w).
img_size (tuple[int, int] | None) – Target output size (height, width). If provided, the output is cropped to this size from the top-left corner.
reduce_overlap (str) – How to handle overlapping regions. Options are "sum" or "mean".
- Returns:
Reconstructed image tensor of shape (B, C, H, W).
- Return type:
torch.Tensor
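The effect of reduce_overlap is easiest to see in one dimension. The following sketch is a pure-Python illustration under stated assumptions, not the library's implementation: `overlap_add_1d` is a hypothetical helper that accumulates patches at their strided positions and, for "mean", divides by the number of patches covering each position.

```python
def overlap_add_1d(patches, stride, length, reduce_overlap="sum"):
    """Hypothetical 1D analogue of patch reconstruction."""
    if reduce_overlap not in ("sum", "mean"):
        raise ValueError("reduce_overlap must be 'sum' or 'mean'")
    total = [0.0] * length
    count = [0] * length
    for k, patch in enumerate(patches):
        for p, value in enumerate(patch):
            idx = k * stride + p
            if idx < length:  # crop anything beyond the target length
                total[idx] += value
                count[idx] += 1
    if reduce_overlap == "mean":
        # Divide each position by the number of patches that cover it.
        return [t / c if c else 0.0 for t, c in zip(total, count)]
    return total


# First row of the 5x5 example image, patch size 3, stride 2.
signal = [0.0, 1.0, 2.0, 3.0, 4.0]
patches = [signal[0:3], signal[2:5]]

print(overlap_add_1d(patches, stride=2, length=5))
# [0.0, 1.0, 4.0, 3.0, 4.0]
print(overlap_add_1d(patches, stride=2, length=5, reduce_overlap="mean"))
# [0.0, 1.0, 2.0, 3.0, 4.0]
```

Only index 2 is covered by both patches, so it is the only value doubled by "sum"; this matches the first row of the summed reconstruction in the class example. When the patches are unmodified, "mean" recovers the original signal exactly.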