# Source code for deepinv.physics.blur

from torchvision.transforms.functional import rotate
import torchvision
import torch
import numpy as np
import torch.fft as fft
from torch import Tensor
from deepinv.physics.forward import LinearPhysics, DecomposablePhysics
from deepinv.physics.functional import (
    conv2d,
    conv_transpose2d,
    filter_fft_2d,
    product_convolution2d,
    product_convolution2d_adjoint,
    conv3d_fft,
    conv_transpose3d_fft,
)


class Downsampling(LinearPhysics):
    r"""
    Downsampling operator for super-resolution problems.

    It is defined as

    .. math::

        y = S (h*x)

    where :math:`h` is a low-pass filter and :math:`S` is a subsampling operator.

    :param torch.Tensor, str, None filter: Downsampling filter. It can be ``'gaussian'``,
        ``'bilinear'``, ``'bicubic'``, ``'sinc'`` or a custom ``torch.Tensor`` filter.
        If ``None``, no filtering is applied.
    :param tuple[int] img_size: size of the input image
    :param int factor: downsampling factor
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.

    |sep|

    :Examples:

        Downsampling operator with a gaussian filter:

        >>> from deepinv.physics import Downsampling
        >>> x = torch.zeros((1, 1, 32, 32)) # Define black image of size 32x32
        >>> x[:, :, 16, 16] = 1 # Define one white pixel in the middle
        >>> physics = Downsampling(filter = "gaussian", img_size=(1, 32, 32), factor=2)
        >>> y = physics(x)
        >>> y[:, :, 7:10, 7:10] # Display the center of the downsampled image
        tensor([[[[0.0146, 0.0241, 0.0146],
                  [0.0241, 0.0398, 0.0241],
                  [0.0146, 0.0241, 0.0146]]]])
    """

    def __init__(
        self,
        img_size,
        filter=None,
        factor=2,
        device="cpu",
        padding="circular",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Factor validation (including float -> int conversion) is delegated to
        # update_parameters, which raises ValueError on invalid values. The
        # previous `assert isinstance(factor, int)` was stripped under
        # `python -O` and rejected floats that update_parameters accepts.
        self.imsize = img_size
        self.padding = padding
        self.device = device
        # Register the filter as a buffer so it follows .to(device) moves.
        self.register_buffer("filter", None)
        self.update_parameters(filter=filter, factor=factor, **kwargs)
        self.to(device)
[docs] def A(self, x, filter=None, factor=None, **kwargs): r""" Applies the downsampling operator to the input image. :param torch.Tensor x: input image. :param None, torch.Tensor filter: Filter :math:`h` to be applied to the input image before downsampling. If not ``None``, it uses this filter and stores it as the current filter. """ self.update_parameters(filter=filter, factor=factor, **kwargs) if self.filter is not None: x = conv2d(x, self.filter, padding=self.padding) x = x[:, :, :: self.factor, :: self.factor] # downsample return x
    def A_adjoint(self, y, filter=None, factor=None, **kwargs):
        r"""
        Adjoint operator of the downsampling operator.

        Zero-fills the missing pixels and (if a filter is set) applies the
        transposed convolution of the anti-aliasing filter.

        :param torch.Tensor y: downsampled image.
        :param None, torch.Tensor filter: Filter :math:`h` to be applied to the input image before
            downsampling. If not ``None``, it uses this filter and stores it as the current filter.
        """
        self.update_parameters(filter=filter, factor=factor, **kwargs)

        imsize = self.imsize
        if self.filter is not None:
            if self.padding == "valid":
                # With 'valid' padding the forward blur shrinks the image, so the
                # adjoint starts from the reduced spatial size.
                imsize = (
                    self.imsize[0],
                    self.imsize[1] - self.filter.shape[-2] + 1,
                    self.imsize[2] - self.filter.shape[-1] + 1,
                )
            else:
                # Same spatial size; rebuilt as a tuple so the concatenation
                # with (batch,) below is well-defined.
                imsize = (
                    self.imsize[0],
                    self.imsize[1],
                    self.imsize[2],
                )

        x = torch.zeros((y.shape[0],) + imsize, device=y.device, dtype=y.dtype)
        x[:, :, :: self.factor, :: self.factor] = y  # upsample by zero-filling
        if self.filter is not None:
            x = conv_transpose2d(
                x, self.filter, padding=self.padding
            )  # Note: this may be slow against x = conv_transpose2d_fft(x, self.filter) in the case of circular padding
        return x
    def prox_l2(self, z, y, gamma, use_fft=True, **kwargs):
        r"""
        If the padding is circular, it computes the proximal operator with the closed-formula of
        https://arxiv.org/abs/1510.00143.

        Otherwise, it computes it using the conjugate gradient algorithm which can be slow
        if applied many times.
        """
        if use_fft and self.padding == "circular":  # Formula from (Zhao, 2016)
            # z_hat = A^T y + (1/gamma) z, then move to the Fourier domain.
            z_hat = self.A_adjoint(y) + 1 / gamma * z
            Fz_hat = fft.fft2(z_hat)

            def splits(a, sf):
                """split a into sfxsf distinct blocks

                Args:
                    a: NxCxWxH
                    sf: split factor

                Returns:
                    b: NxCx(W/sf)x(H/sf)x(sf^2)
                """
                b = torch.stack(torch.chunk(a, sf, dim=2), dim=4)
                b = torch.cat(torch.chunk(b, sf, dim=3), dim=4)
                return b

            # Closed-form prox in the Fourier domain, using the precomputed
            # filter FFTs Fh, Fhc, Fh2 registered by update_parameters.
            top = torch.mean(splits(self.Fh * Fz_hat, self.factor), dim=-1)
            below = torch.mean(splits(self.Fh2, self.factor), dim=-1) + 1 / gamma
            rc = self.Fhc * (top / below).repeat(1, 1, self.factor, self.factor)
            r = torch.real(fft.ifft2(rc))
            return (z_hat - r) * gamma
        else:
            # Fall back to the generic (iterative) prox of LinearPhysics.
            return LinearPhysics.prox_l2(self, z, y, gamma, **kwargs)
[docs] def update_parameters(self, filter=None, factor=None, **kwargs): r""" Updates the current filter and/or factor. :param torch.Tensor filter: New filter to be applied to the input image. :param int factor: New downsampling factor to be applied to the input image. """ if factor is not None: if isinstance(factor, (int, float)): self.factor = int(factor) else: raise ValueError( f"Factor must be an integer, got {factor} of type {type(factor)}." ) if filter is not None: if isinstance(filter, torch.Tensor): filter = filter.to(self.device) elif filter == "gaussian": filter = gaussian_blur(sigma=(self.factor, self.factor)).to(self.device) elif filter == "bilinear": filter = bilinear_filter(self.factor).to(self.device) elif filter == "bicubic": filter = bicubic_filter(self.factor).to(self.device) elif filter == "sinc": filter = sinc_filter(self.factor, length=4 * self.factor).to( self.device ) self.register_buffer("filter", filter) if self.filter is not None: self.register_buffer( "Fh", filter_fft_2d(self.filter, self.imsize, real_fft=False).to(self.device), ) self.register_buffer("Fhc", torch.conj(self.Fh)) self.register_buffer("Fh2", self.Fhc * self.Fh) super().update_parameters(**kwargs)
class Blur(LinearPhysics):
    r"""
    Blur operator.

    This forward operator performs

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    :param torch.Tensor filter: Tensor of size (b, 1, h, w) or (b, c, h, w) in 2D;
        (b, 1, d, h, w) or (b, c, d, h, w) in 3D, containing the blur filter, e.g.,
        :func:`deepinv.physics.blur.gaussian_blur`.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image (default is ``'valid'``).
        Only ``padding='valid'`` and ``padding='circular'`` are implemented in 3D.
    :param str device: cpu or cuda.

    .. note::

        This class makes it possible to change the filter at runtime by passing a new filter to
        the forward method, e.g., ``y = physics(x, w)``. The new filter :math:`w` is stored as
        the current filter.

    .. note::

        This class uses the highly optimized :func:`torch.nn.functional.conv2d` for performing
        the convolutions in 2D and FFT for performing the convolutions in 3D as implemented in
        :func:`deepinv.physics.functional.conv3d_fft`. It uses FFT based convolutions in 3D
        since :func:`torch.nn.functional.conv3d` is slow for large kernels.

    |sep|

    :Examples:

        Blur operator with a basic averaging filter applied to a 16x16 black image with a
        single white pixel in the center:

        >>> from deepinv.physics import Blur
        >>> x = torch.zeros((1, 1, 16, 16)) # Define black image of size 16x16
        >>> x[:, :, 8, 8] = 1 # Define one white pixel in the middle
        >>> w = torch.ones((1, 1, 2, 2)) / 4 # Basic 2x2 averaging filter
        >>> physics = Blur(filter=w)
        >>> y = physics(x)
        >>> y[:, :, 7:10, 7:10] # Display the center of the blurred image
        tensor([[[[0.2500, 0.2500, 0.0000],
                  [0.2500, 0.2500, 0.0000],
                  [0.0000, 0.0000, 0.0000]]]])
    """

    def __init__(self, filter=None, padding="valid", device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.device = device
        self.padding = padding
        # Validate with a real exception rather than `assert`, which is
        # stripped when Python runs with optimizations (-O).
        if not (isinstance(filter, Tensor) or filter is None):
            raise TypeError(
                f"The filter must be a torch.Tensor or None, got filter of type {type(filter)}."
            )
        # Register as a buffer so the filter follows .to(device) moves.
        self.register_buffer("filter", filter)
        self.to(device)
[docs] def A(self, x, filter=None, **kwargs): r""" Applies the filter to the input image. :param torch.Tensor x: input image. :param torch.Tensor filter: Filter :math:`w` to be applied to the input image. If not ``None``, it uses this filter instead of the one defined in the class, and the provided filter is stored as the current filter. """ self.update_parameters(filter=filter, **kwargs) if x.dim() == 4: return conv2d(x, filter=self.filter, padding=self.padding) elif x.dim() == 5: return conv3d_fft(x, filter=self.filter, padding=self.padding)
[docs] def A_adjoint(self, y, filter=None, **kwargs): r""" Adjoint operator of the blur operator. :param torch.Tensor y: blurred image. :param torch.Tensor filter: Filter :math:`w` to be applied to the input image. If not ``None``, it uses this filter instead of the one defined in the class, and the provided filter is stored as the current filter. """ self.update_parameters(filter=filter, **kwargs) if y.dim() == 4: return conv_transpose2d(y, filter=self.filter, padding=self.padding) elif y.dim() == 5: return conv_transpose3d_fft(y, filter=self.filter, padding=self.padding)
class BlurFFT(DecomposablePhysics):
    """
    FFT-based blur operator.

    It performs the operation

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    Blur operator based on ``torch.fft`` operations, which assumes a circular padding of the
    input, and allows for the singular value decomposition via
    ``deepinv.Physics.DecomposablePhysics`` and has fast pseudo-inverse and prox operators.

    :param tuple img_size: Input image size in the form (C, H, W).
    :param torch.Tensor filter: torch.Tensor of size (1, c, h, w) containing the blur filter
        with h<=H, w<=W and c=1 or c=C e.g., :func:`deepinv.physics.blur.gaussian_blur`.
    :param str device: cpu or cuda

    |sep|

    :Examples:

        BlurFFT operator with a basic averaging filter applied to a 16x16 black image with
        a single white pixel in the center:

        >>> from deepinv.physics import BlurFFT
        >>> x = torch.zeros((1, 1, 16, 16)) # Define black image of size 16x16
        >>> x[:, :, 8, 8] = 1 # Define one white pixel in the middle
        >>> filter = torch.ones((1, 1, 2, 2)) / 4 # Basic 2x2 filter
        >>> physics = BlurFFT(filter=filter, img_size=(1, 16, 16))
        >>> y = physics(x)
        >>> y[y<1e-5] = 0.
        >>> y[:, :, 7:10, 7:10] # Display the center of the blurred image
        tensor([[[[0.2500, 0.2500, 0.0000],
                  [0.2500, 0.2500, 0.0000],
                  [0.0000, 0.0000, 0.0000]]]])
    """

    def __init__(self, img_size, filter: Tensor = None, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size
        assert (
            isinstance(filter, Tensor) or filter is None
        ), f"The filter must be a torch.Tensor or None, got filter of type {type(filter)}."
        # Computes and registers the SVD factors (mask/angle) of the filter.
        self.update_parameters(filter=filter, **kwargs)
        self.to(device)

    def A(self, x: Tensor, filter: Tensor = None, **kwargs) -> Tensor:
        # Forward operator; a new filter (if given) replaces the stored one.
        self.update_parameters(filter=filter, **kwargs)
        return super().A(x)

    def A_adjoint(self, x: Tensor, filter: Tensor = None, **kwargs) -> Tensor:
        # Adjoint operator; a new filter (if given) replaces the stored one.
        self.update_parameters(filter=filter, **kwargs)
        return super().A_adjoint(x)

    def V_adjoint(self, x: Tensor) -> Tensor:
        # Real FFT expressed over real/imag channels so the factorization
        # A = U diag(mask) V^T holds with real-valued factors.
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho")
        )  # make it a true SVD (see J. Romberg notes)

    def U(self, x):
        # Applies the filter phase before the inverse real FFT.
        return fft.irfft2(
            torch.view_as_complex(x) * self.angle,
            norm="ortho",
            s=self.img_size[-2:],
        )

    def U_adjoint(self, x):
        # Conjugate phase: adjoint of U.
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho") * torch.conj(self.angle)
        )  # make it a true SVD (see J. Romberg notes)

    def V(self, x):
        # Inverse real FFT back to the image domain.
        return fft.irfft2(torch.view_as_complex(x), norm="ortho", s=self.img_size[-2:])
    def update_parameters(self, filter: Tensor = None, **kwargs):
        r"""
        Updates the current filter.

        :param torch.Tensor filter: New filter to be applied to the input image.
        """
        if filter is not None and isinstance(filter, Tensor):
            # Broadcast a single-channel filter to all image channels.
            if self.img_size[0] > filter.shape[1]:
                filter = filter.repeat(1, self.img_size[0], 1, 1)
            # Split the filter FFT into magnitude (`mask`, the singular values)
            # and phase (`angle`, used by U / U_adjoint).
            mask = filter_fft_2d(filter, self.img_size)
            angle = torch.angle(mask)
            mask = torch.abs(mask).unsqueeze(-1)
            # Duplicate along a trailing dim to match the view_as_real layout
            # (real and imaginary channels share the same singular value).
            mask = torch.cat([mask, mask], dim=-1)
            self.register_buffer("filter", filter)
            self.register_buffer("angle", torch.exp(-1.0j * angle))
            self.register_buffer("mask", mask)
        super().update_parameters(**kwargs)
class SpaceVaryingBlur(LinearPhysics):
    r"""
    Implements a space varying blur via product-convolution.

    This operator performs

    .. math::

        y = \sum_{k=1}^K h_k \star (w_k \odot x)

    where :math:`\star` is a convolution, :math:`\odot` is a Hadamard product,
    :math:`w_k` are multipliers :math:`h_k` are filters.

    :param torch.Tensor multipliers: Multipliers :math:`w_k`. Tensor of size (b, c, K, H, W).
        b in {1, B} and c in {1, C}
    :param torch.Tensor filters: Filters :math:`h_k`. Tensor of size (b, c, K, h, w).
        b in {1, B} and c in {1, C}, h<=H and w<=W
    :param padding: options = ``'valid'``, ``'circular'``, ``'replicate'``, ``'reflect'``.
        If ``padding = 'valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    :param str device: cpu or cuda

    |sep|

    :Examples:

        We show how to instantiate a spatially varying blur operator.

        >>> from deepinv.physics.generator import DiffractionBlurGenerator, ProductConvolutionBlurGenerator
        >>> from deepinv.physics.blur import SpaceVaryingBlur
        >>> from deepinv.utils.plotting import plot
        >>> psf_size = 32
        >>> img_size = (256, 256)
        >>> delta = 16
        >>> psf_generator = DiffractionBlurGenerator((psf_size, psf_size))
        >>> pc_generator = ProductConvolutionBlurGenerator(psf_generator=psf_generator, img_size=img_size)
        >>> params_pc = pc_generator.step(1)
        >>> physics = SpaceVaryingBlur(**params_pc)
        >>> dirac_comb = torch.zeros(img_size).unsqueeze(0).unsqueeze(0)
        >>> dirac_comb[0,0,::delta,::delta] = 1
        >>> psf_grid = physics(dirac_comb)
        >>> plot(psf_grid, titles="Space varying impulse responses")
    """

    def __init__(
        self,
        filters: Tensor = None,
        multipliers: Tensor = None,
        padding: str = None,
        device="cpu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Only the product-convolution parameterization is implemented.
        self.method = "product_convolution2d"
        if self.method == "product_convolution2d":
            self.update_parameters(filters, multipliers, padding, **kwargs)
        self.to(device)
        # NOTE(review): if `padding` stays None, update_parameters never sets
        # self.padding, so a later A()/A_adjoint() call without an explicit
        # padding raises AttributeError — confirm callers always provide it.
    def A(
        self, x: Tensor, filters=None, multipliers=None, padding=None, **kwargs
    ) -> torch.Tensor:
        r"""
        Applies the space varying blur operator to the input image.

        It can receive new parameters :math:`w_k`, :math:`h_k` and padding to be used in the
        forward operator, and stored as the current parameters.

        :param torch.Tensor x: input image.
        :param torch.Tensor filters: Filters :math:`h_k`. Tensor of size (b, c, K, h, w).
            b in {1, B} and c in {1, C}, h<=H and w<=W.
        :param torch.Tensor multipliers: Multipliers :math:`w_k`. Tensor of size (b, c, K, H, W).
            b in {1, B} and c in {1, C}.
        :param padding: options = ``'valid'``, ``'circular'``, ``'replicate'``, ``'reflect'``.
            If ``padding = 'valid'`` the blurred output is smaller than the image (no padding),
            otherwise the blurred output has the same size as the image.
        """
        if self.method == "product_convolution2d":
            self.update_parameters(filters, multipliers, padding, **kwargs)
            return product_convolution2d(
                x, self.multipliers, self.filters, self.padding
            )
        else:
            raise NotImplementedError("Method not implemented in product-convolution")
    def A_adjoint(
        self, y: Tensor, filters=None, multipliers=None, padding=None, **kwargs
    ) -> torch.Tensor:
        r"""
        Applies the adjoint operator.

        It can receive new parameters :math:`w_k`, :math:`h_k` and padding to be used in the
        forward operator, and stored as the current parameters.

        :param torch.Tensor y: blurred image.
        :param torch.Tensor filters: Filters :math:`h_k`. Tensor of size (b, c, K, h, w).
            b in {1, B} and c in {1, C}, h<=H and w<=W.
        :param torch.Tensor multipliers: Multipliers :math:`w_k`. Tensor of size (b, c, K, H, W).
            b in {1, B} and c in {1, C}.
        :param padding: options = ``'valid'``, ``'circular'``, ``'replicate'``, ``'reflect'``.
            If ``padding = 'valid'`` the blurred output is smaller than the image (no padding),
            otherwise the blurred output has the same size as the image.
        """
        if self.method == "product_convolution2d":
            self.update_parameters(
                filters=filters, multipliers=multipliers, padding=padding, **kwargs
            )
            return product_convolution2d_adjoint(
                y, self.multipliers, self.filters, self.padding
            )
        else:
            raise NotImplementedError("Method not implemented in product-convolution")
[docs] def update_parameters( self, filters: Tensor = None, multipliers: Tensor = None, padding: str = None, **kwargs, ): r""" Updates the current parameters. :param torch.Tensor filters: Multipliers :math:`w_k`. Tensor of size (b, c, K, H, W). b in {1, B} and c in {1, C} :param torch.Tensor multipliers: Filters :math:`h_k`. Tensor of size (b, c, K, h, w). b in {1, B} and c in {1, C}, h<=H and w<=W :param padding: options = ``'valid'``, ``'circular'``, ``'replicate'``, ``'reflect'``. """ if filters is not None and isinstance(filters, Tensor): self.register_buffer("filters", filters) if multipliers is not None and isinstance(filters, Tensor): self.register_buffer("multipliers", multipliers) if padding is not None: self.padding = padding super().update_parameters(**kwargs)
def gaussian_blur(sigma=(1, 1), angle=0):
    r"""
    Gaussian blur filter.

    Defined as

    .. math::
        \begin{equation*}
            G(x, y) = \frac{1}{2\pi\sigma_x\sigma_y} \exp{\left(-\frac{x'^2}{2\sigma_x^2} - \frac{y'^2}{2\sigma_y^2}\right)}
        \end{equation*}

    where :math:`x'` and :math:`y'` are the rotated coordinates obtained by rotating
    :math:`(x, y)` around the origin by an angle :math:`\theta`:

    .. math::
        \begin{align*}
            x' &= x \cos(\theta) - y \sin(\theta) \\
            y' &= x \sin(\theta) + y \cos(\theta)
        \end{align*}

    with :math:`\sigma_x` and :math:`\sigma_y` the standard deviations along the
    :math:`x'` and :math:`y'` axes.

    :param float, tuple[float] sigma: standard deviation of the gaussian filter. If sigma is a
        float the filter is isotropic, whereas if sigma is a tuple of floats
        (sigma_x, sigma_y) the filter is anisotropic.
    :param float angle: rotation angle of the filter in degrees (only useful for
        anisotropic filters)
    """
    if isinstance(sigma, (int, float)):
        sigma = (sigma, sigma)

    # Kernel half-width: large enough to cover the support of the widest axis.
    half = int(max(sigma) / 0.3 + 1)
    size = 2 * half + 1

    # Centered integer grid and the (unrotated) anisotropic Gaussian profile.
    coords = torch.arange(size) - half
    xx, yy = torch.meshgrid(coords, coords, indexing="ij")
    kernel = torch.exp(-((xx / sigma[0]) ** 2 + (yy / sigma[1]) ** 2) / 2.0)

    # Rotate the kernel by `angle` degrees (bilinear interpolation).
    kernel = rotate(
        kernel[None, None],
        angle,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )[0, 0]

    # Normalize to unit mass and add batch/channel dimensions.
    kernel = kernel / kernel.sum()
    return kernel[None, None]
def kaiser_window(beta, length, device="cpu"):
    """Return the Kaiser window of length `length` and shape parameter `beta`.

    :param float beta: non-negative shape parameter of the window.
    :param int length: number of samples in the window (must be >= 1).
    :param str device: device on which the window tensor is created.
    :raises ValueError: if ``beta`` is negative or ``length`` is smaller than 1.
    """
    if beta < 0:
        # Message fixed: the check accepts beta == 0 (rectangular window).
        raise ValueError("beta must be non-negative")
    if length < 1:
        raise ValueError("length must be greater than 0")
    if length == 1:
        # Degenerate single-sample window; honor the requested device
        # (previously this path returned a CPU tensor regardless of `device`).
        return torch.ones(1, device=device)
    half = (length - 1) / 2
    n = torch.arange(length, device=device)
    beta = torch.tensor(beta, device=device)
    # w[n] = I0(beta * sqrt(1 - ((n - half)/half)^2)) / I0(beta)
    return torch.i0(beta * torch.sqrt(1 - ((n - half) / half) ** 2)) / torch.i0(beta)
def sinc_filter(factor=2, length=11, windowed=True, device="cpu"):
    r"""
    Anti-aliasing sinc filter multiplied by a Kaiser window.

    The kaiser window parameter is computed as follows:

    .. math::
        A = 2.285 \cdot (L - 1) \cdot 3.14 \cdot \Delta f + 7.95

    where :math:`\Delta f = 2 (2 - \sqrt{2}) / \text{factor}`.

    Then, the beta parameter is computed as:

    .. math::
        \begin{equation*}
            \beta = \begin{cases}
                0 & \text{if } A \leq 21 \\
                0.5842 \cdot (A - 21)^{0.4} + 0.07886 \cdot (A - 21) & \text{if } 21 < A \leq 50 \\
                0.1102 \cdot (A - 8.7) & \text{otherwise}
            \end{cases}
        \end{equation*}

    :param float factor: Downsampling factor.
    :param int length: Length of the filter.
    """
    if isinstance(factor, torch.Tensor):
        factor = factor.cpu().item()

    # Transition bandwidth of the anti-aliasing filter.
    delta_f = 2 * (2 - 1.4142136) / factor

    # Centered 1D sinc profile.
    grid = torch.arange(length, device=device) - (length - 1) / 2
    profile = torch.sinc(grid / factor)

    if windowed:
        # Kaiser design: attenuation A -> shape parameter beta.
        atten = 2.285 * (length - 1) * 3.14159 * delta_f + 7.95
        if atten <= 21:
            beta = 0
        elif atten <= 50:
            beta = 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21)
        else:
            beta = 0.1102 * (atten - 8.7)
        profile = profile * kaiser_window(beta, length, device=device)

    # Separable 2D kernel, normalized to unit mass, with batch/channel dims.
    kernel = torch.outer(profile, profile)
    kernel = kernel / kernel.sum()
    return kernel[None, None]
def bilinear_filter(factor=2):
    r"""
    Bilinear filter.

    It has size (2*factor, 2*factor) and is defined as

    .. math::
        \begin{equation*}
            w(x, y) = \begin{cases}
                (1 - |x|) \cdot (1 - |y|) & \text{if } |x| \leq 1 \text{ and } |y| \leq 1 \\
                0 & \text{otherwise}
            \end{cases}
        \end{equation*}

    for :math:`x, y \in {-\text{factor} + 0.5, -\text{factor} + 0.5 + 1/\text{factor}, \ldots, \text{factor} - 0.5}`.

    :param int factor: downsampling factor
    """
    if isinstance(factor, torch.Tensor):
        factor = factor.cpu().item()
    # Sample the triangle function 1 - |x| on 2*factor points centered at zero.
    grid = torch.arange(start=-factor + 0.5, end=factor, step=1) / factor
    profile = 1 - grid.abs()
    # Separable 2D kernel, normalized to unit mass, with batch/channel dims.
    kernel = torch.outer(profile, profile)
    kernel = kernel / kernel.sum()
    return kernel[None, None]
def bicubic_filter(factor=2):
    r"""
    Bicubic filter.

    It has size (4*factor, 4*factor) and is defined as

    .. math::
        \begin{equation*}
            w(x, y) = \begin{cases}
                (a + 2)|x|^3 - (a + 3)|x|^2 + 1 & \text{if } |x| \leq 1 \\
                a|x|^3 - 5a|x|^2 + 8a|x| - 4a & \text{if } 1 < |x| < 2 \\
                0 & \text{otherwise}
            \end{cases}
        \end{equation*}

    for :math:`x, y \in {-2\text{factor} + 0.5, -2\text{factor} + 0.5 + 1/\text{factor}, \ldots, 2\text{factor} - 0.5}`.

    :param int factor: downsampling factor
    """
    if isinstance(factor, torch.Tensor):
        factor = factor.cpu().item()
    a = -0.5  # common bicubic convolution parameter
    # Sample positions covering |x| < 2, 4*factor points centered at zero.
    grid = torch.arange(start=-2 * factor + 0.5, end=2 * factor, step=1) / factor
    t = grid.abs()
    # Piecewise cubic profile: inner lobe (|x| <= 1) plus outer lobe (1 < |x| < 2).
    near = ((a + 2) * t**3 - (a + 3) * t**2 + 1) * (t <= 1)
    far = (a * t**3 - 5 * a * t**2 + 8 * a * t - 4 * a) * ((t > 1) & (t < 2))
    profile = near + far
    # Separable 2D kernel, normalized to unit mass, with batch/channel dims.
    kernel = torch.outer(profile, profile)
    kernel = kernel / kernel.sum()
    return kernel[None, None]