Source code for deepinv.physics.range

import torch
from deepinv.physics.forward import DecomposablePhysics


class Decolorize(DecomposablePhysics):
    r"""
    Converts n-channel images to grayscale.

    The image channels are multiplied by factors determined by the spectral response function (SRF),
    then summed to produce a grayscale image. We provide various ways of defining the SRF including
    the `rec601 <https://en.wikipedia.org/wiki/Rec._601>`_ convention for RGB images.

    In the adjoint operation, we multiply the grayscale image by the coefficients in the SRF.

    Images must be tensors with C channels, i.e. ``(B,C,H,W)``. The measurements are grayscale images.

    :param int channels: number of channels in the input image.
    :param str, tuple, list srf: spectral response function. Either pass in a user-defined SRF
        (of length at most ``channels``; shorter SRFs are padded with zeros), or ``rec601`` (default,
        also used when ``None`` is passed) following the
        `rec601 <https://en.wikipedia.org/wiki/Rec._601>`_ convention, or ``flat`` (alias ``average``)
        for a flat SRF (i.e. averages channels), or ``random`` for random SRF
        (e.g. to initialise joint learning). The SRF coefficients must sum to 1 (``rtol=1e-4``).
    :param str, torch.device device: device on which to perform the computations. Default: ``cpu``.
    :raises ValueError: if ``srf`` is not one of the supported values, or has more than
        ``channels`` entries.
    :raises AssertionError: if the SRF coefficients do not sum to 1.

    |sep|

    :Examples:

        Decolorize a 3x3 image:

        >>> import torch
        >>> from deepinv.physics import Decolorize
        >>> x = torch.ones((1, 3, 3, 3), requires_grad=False) # 3x3 RGB image
        >>> physics = Decolorize()
        >>> physics(x)
        tensor([[[[1.0000, 1.0000, 1.0000],
                  [1.0000, 1.0000, 1.0000],
                  [1.0000, 1.0000, 1.0000]]]])

    """

    def __init__(self, channels=3, srf="rec601", device="cpu", **kwargs):
        super().__init__(**kwargs)

        if srf is None or srf == "rec601":
            srf = torch.tensor([0.4472 * 0.66851, 0.8781 * 0.66851, 0.1706 * 0.66851])
        elif srf in ("average", "flat"):
            srf = torch.tensor([1 / channels] * channels)
        elif srf == "random":
            srf = torch.rand(channels, device=device)
            srf /= srf.sum()
        elif isinstance(srf, (tuple, list)):
            # Force a floating dtype: an all-integer SRF such as [1, 0, 0] would
            # otherwise become an int64 tensor, which breaks the sum check and
            # the norm computation below.
            srf = torch.tensor(srf, dtype=torch.get_default_dtype())
        else:
            raise ValueError("Invalid srf")

        n_coeffs = srf.size(0)
        if n_coeffs < channels:
            # Pad with zeros so that the extra channels do not contribute.
            padding = torch.zeros(
                channels - n_coeffs, dtype=srf.dtype, device=srf.device
            )
            srf = torch.cat([srf, padding])
        elif n_coeffs > channels:
            raise ValueError("srf should be of length equal to or less than channels.")

        # Raised explicitly (rather than via ``assert``) so that the check is
        # not stripped under ``python -O``; the exception type is kept as
        # AssertionError for backward compatibility.
        total = srf.sum()
        if not torch.allclose(total, torch.ones_like(total), rtol=1e-4):
            raise AssertionError(
                f"srf coefficients should sum to 1, but sum to {total.item()}."
            )

        self.register_buffer("srf", srf.to(device).view(1, channels, 1, 1))
        # L2 norm of the SRF along the channel dimension, so that
        # ``srf / mask`` has unit norm.
        self.register_buffer(
            "mask", torch.linalg.vector_norm(self.srf, dim=1, keepdim=True)
        )
        self.to(device)

    def V_adjoint(self, x):
        r"""
        Weights the channels of ``x`` by the normalized SRF and sums them.

        :param torch.Tensor x: input whose dimension 1 has as many entries as the SRF.
        :return: tensor with a single entry in dimension 1.
        :raises ValueError: if the channel dimension of ``x`` does not match the SRF.
        """
        if x.shape[1] != self.srf.shape[1]:
            raise ValueError("x should have same number of channels as SRF.")

        return torch.sum(x * self.srf / self.mask, dim=1, keepdim=True)

    def V(self, y):
        r"""
        Expands a grayscale input along dimension 1 and weights it by the normalized SRF.

        :param torch.Tensor y: input with a single entry in dimension 1.
        :return: tensor whose dimension 1 has as many entries as the SRF.
        :raises ValueError: if ``y`` does not have length 1 in dimension 1.
        """
        if y.shape[1] != 1:
            raise ValueError(
                "y should be grayscale i.e. have length 1 in the 1st dimension."
            )

        return (
            y.expand(y.shape[0], self.srf.shape[1], *y.shape[2:])
            * self.srf
            / self.mask
        )