Source code for deepinv.transform.reflect

from typing import Union, Iterable
import torch
from torchvision.transforms.functional import rotate
from torchvision.transforms import InterpolationMode
import numpy as np
from deepinv.transform.base import Transform, TransformParam
import itertools


[docs] class Reflect(Transform): r""" Reflect (flip) in random multiple axes. Generates ``n_trans`` reflected images, each time subselecting axes from dim (without replacement). Hence to transform through all group elements, set ``n_trans`` to ``2**len(dim)`` e.g ``Reflect(dim=[-2, -1], n_trans=4)`` See :class:`deepinv.transform.Transform` for further details and examples. :param int, list[int] dim: axis or axes on which to randomly select axes to reflect. :param int n_trans: number of transformed versions generated per input image. :param torch.Generator rng: random number generator, if None, use torch.Generator(), defaults to None """ def __init__( self, *args, dim: Union[int, list[int]] = [-2, -1], **kwargs, ): super().__init__(*args, **kwargs) self.dim = dim def _get_params(self, x: torch.Tensor) -> dict: """Randomly generate sets of reflection axes without replacement. :param torch.Tensor x: input image :return dict: keyword args with dims = tensor of which axes to flip, one row per n_trans, padded with nans. """ subsets = list( itertools.chain.from_iterable( itertools.combinations(self.dim, r) for r in range(len(self.dim) + 1) ) ) idx = torch.randperm(len(subsets), generator=self.rng, device=self.rng.device)[ : self.n_trans ] out = torch.full( (self.n_trans, len(self.dim)), fill_value=float("nan"), device=x.device ) for i, id in enumerate(idx): out[i, : len(subsets[id])] = torch.tensor(subsets[id], dtype=torch.int) return {"dims": TransformParam(out, neg=lambda x: x)} def _transform( self, x: torch.Tensor, dims: Union[torch.Tensor, Iterable] = [], **kwargs, ) -> torch.Tensor: """Reflect image in axes given in dim. :param torch.Tensor x: input image of shape (B,C,H,W) :param torch.Tensor, list dims: tensor with n_trans rows of axes to subselect for each reflected image. NaN axes are ignored. :return: torch.Tensor: transformed images. """ dims = [dim[~torch.isnan(dim)].int().tolist() for dim in dims] return torch.cat( [torch.flip(x, dims=dim) if len(dim) > 0 else x for dim in dims] )