conv3d_fft

deepinv.physics.functional.conv3d_fft(x, filter, real_fft=True, padding='valid')

A helper function that performs the 3D convolution of x and filter using the FFT.

The adjoint of this operation is deepinv.physics.functional.conv_transpose3d_fft().

If b = 1, the same filter is applied to every image in the batch; if c = 1, the same filter is applied to every channel. Otherwise, each channel of each image is convolved with the corresponding kernel.

Parameters:
  • x (torch.Tensor) – Image of size (B, C, D, H, W).

  • filter (torch.Tensor) – Filter of size (b, c, d, h, w) where b can be either 1 or B and c can be either 1 or C.

  • real_fft (bool) – Set to True (default) for real-valued filters and images to accelerate the computation.

  • padding (str) – can be 'valid', 'circular', 'replicate', 'reflect', 'constant' or 'zeros'. If padding = 'valid' the output is smaller than the image (no padding), otherwise the output has the same size as the image. Note that 'constant' and 'zeros' are equivalent. Default is 'valid'.
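The shape rules stated in the parameter list can be sketched as a small pure-Python helper. The name conv3d_output_shape is hypothetical and is not part of deepinv; it only encodes what the descriptions above say about the filter size and the padding modes.

```python
def conv3d_output_shape(image_shape, filter_shape, padding="valid"):
    """Hypothetical helper (not a deepinv function): expected output shape.

    image_shape is (B, C, D, H, W) and filter_shape is (b, c, d, h, w),
    following the parameter descriptions above.
    """
    B, C, D, H, W = image_shape
    b, c, d, h, w = filter_shape

    # b must be 1 or B, and c must be 1 or C.
    if b not in (1, B) or c not in (1, C):
        raise ValueError("filter batch/channel sizes must be 1 or match the image")

    if padding == "valid":
        # No padding: each spatial dimension shrinks by (filter size - 1).
        return (B, C, D - d + 1, H - h + 1, W - w + 1)

    if padding in ("circular", "replicate", "reflect", "constant", "zeros"):
        # Any padded mode keeps the size of the image.
        return (B, C, D, H, W)

    raise ValueError(f"unknown padding mode: {padding!r}")


valid_shape = conv3d_output_shape((2, 3, 8, 9, 10), (1, 1, 3, 4, 5))
same_shape = conv3d_output_shape((2, 3, 8, 9, 10), (2, 3, 3, 4, 5), "circular")
print(valid_shape, same_shape)
```

For an image of size (2, 3, 8, 9, 10) and a filter with spatial size (3, 4, 5), the 'valid' case loses 2, 3 and 4 samples along D, H and W respectively.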

Note

The filter center is located at (d//2, h//2, w//2).

This function is equivalent to deepinv.physics.functional.conv3d(). It is more efficient for large filters, but requires more memory.
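To illustrate why an FFT-based and a direct convolution can agree, here is a minimal sketch in plain PyTorch. It is not the deepinv implementation: the function fft_conv3d_valid is a hypothetical name, it handles only a single real filter of size (1, 1, d, h, w) with 'valid' output, and it assumes the true-convolution (flipped-kernel) convention. The direct reference uses torch.nn.functional.conv3d, which computes a cross-correlation, so the kernel is flipped before the comparison.

```python
import torch
import torch.nn.functional as F


def fft_conv3d_valid(x, kernel):
    """Sketch of a 'valid' 3D convolution via the real FFT (hypothetical helper).

    x: real tensor of size (B, C, D, H, W).
    kernel: real tensor of size (1, 1, d, h, w), shared by all images and channels.
    Returns a tensor of size (B, C, D-d+1, H-h+1, W-w+1).
    """
    D, H, W = x.shape[-3:]
    d, h, w = kernel.shape[-3:]
    dims = (-3, -2, -1)

    # Passing s= zero-pads the kernel to the image size before transforming.
    X = torch.fft.rfftn(x, dim=dims)
    K = torch.fft.rfftn(kernel, s=(D, H, W), dim=dims)

    # A product in the Fourier domain is a circular convolution in space.
    y = torch.fft.irfftn(X * K, s=(D, H, W), dim=dims)

    # Drop the first (size - 1) samples per axis, which are affected by wrap-around.
    return y[..., d - 1:, h - 1:, w - 1:]


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 9, 10, dtype=torch.float64)
kernel = torch.randn(1, 1, 3, 4, 5, dtype=torch.float64)

y_fft = fft_conv3d_valid(x, kernel)

# Direct reference: fold channels into the batch and flip the kernel so that
# PyTorch's cross-correlation becomes a true convolution.
flipped = torch.flip(kernel, dims=(-3, -2, -1))
ref = F.conv3d(x.reshape(6, 1, 8, 9, 10), flipped).reshape(2, 3, 6, 6, 6)

print("max abs difference:", (y_fft - ref).abs().max().item())
```

The FFT version transforms full-size volumes, which is where the extra memory goes, while its cost does not grow with the filter size the way the direct sum does.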

Returns:

torch.Tensor: the output of the convolution, which has shape (B, C, D-d+1, H-h+1, W-w+1) if padding = 'valid' and the same shape as \(x\) otherwise.

Return type:

Tensor