WaveletDenoiser

class deepinv.models.WaveletDenoiser(level=3, wv='db8', device='cpu', non_linearity='soft', wvdim=2)[source]

Bases: Module

Orthogonal wavelet denoising with an \(\ell_1\) or \(\ell_0\) sparsity penalty.

This denoiser is defined as the solution to the optimization problem:

\[\underset{x}{\arg\min} \; \|x-y\|^2 + \gamma \|\Psi x\|_n\]

where \(\Psi\) is an orthonormal wavelet transform, \(\gamma>0\) is a hyperparameter, and \(\|\cdot\|_n\) is either the \(\ell_1\) norm (non_linearity="soft") or the \(\ell_0\) norm (non_linearity="hard"). A variant of the \(\ell_0\) penalty is also available (non_linearity="topk"): instead of comparing each coefficient to a threshold, it keeps the \(k\) largest coefficients in each wavelet subband and sets the others to zero.
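Because \(\Psi\) is orthonormal, \(\|x-y\|^2 = \|\Psi x - \Psi y\|^2\). Writing \(z = \Psi x\) and \(c = \Psi y\), the problem therefore splits into one scalar problem per coefficient, each solved by thresholding. For the objective exactly as written above:

\[\hat{x} = \Psi^\top \hat{z}, \qquad \hat{z}_i = \underset{z_i}{\arg\min}\; (z_i - c_i)^2 + \gamma\, \phi(z_i),\]

\[\phi(u) = |u| \;\Rightarrow\; \hat{z}_i = \operatorname{sign}(c_i)\max\left(|c_i| - \tfrac{\gamma}{2},\, 0\right), \qquad \phi(u) = \mathbb{1}[u \neq 0] \;\Rightarrow\; \hat{z}_i = c_i\,\mathbb{1}\left[|c_i| > \sqrt{\gamma}\right].\]

If the data term is scaled by \(\tfrac{1}{2}\) instead, the thresholds become \(\gamma\) and \(\sqrt{2\gamma}\) respectively; which scaling corresponds to the ths argument is not stated on this page. As noted under reshape_ths, the approximation coefficients are not thresholded in practice.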

The solution is available in closed form (thresholding of the wavelet coefficients followed by the inverse transform), so the denoiser is cheap to compute.
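To make the closed form concrete, here is a minimal NumPy sketch of the same pipeline (decompose, threshold the detail coefficients, recompose). It is not the deepinv implementation: it substitutes a multi-level orthonormal 2D Haar transform for the default "db8" wavelet because Haar fits in a few lines, it works on a single 2D array rather than a batched torch.Tensor, it applies ths directly as the threshold, and it assumes both image sides are divisible by 2**level. The names haar2d, ihaar2d and wavelet_denoise are hypothetical.

```python
import numpy as np


def haar2d(x):
    """One level of the orthonormal 2D Haar transform (x must have even sides)."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]
    c, d = x[1::2, 0::2], x[1::2, 1::2]
    approx = (a + b + c + d) / 2
    details = ((a - b + c - d) / 2,  # differences across columns
               (a + b - c - d) / 2,  # differences across rows
               (a - b - c + d) / 2)  # diagonal differences
    return approx, details


def ihaar2d(approx, details):
    """Exact inverse of haar2d (the transform is orthonormal)."""
    h, v, dg = details
    x = np.empty((2 * approx.shape[0], 2 * approx.shape[1]))
    x[0::2, 0::2] = (approx + h + v + dg) / 2
    x[0::2, 1::2] = (approx - h + v - dg) / 2
    x[1::2, 0::2] = (approx + h - v - dg) / 2
    x[1::2, 1::2] = (approx - h - v + dg) / 2
    return x


def soft(c, t):
    return np.sign(c) * np.maximum(np.abs(c) - t, 0.0)


def hard(c, t):
    return c * (np.abs(c) > t)


def wavelet_denoise(y, ths=0.1, level=3, non_linearity="soft"):
    threshold = {"soft": soft, "hard": hard}[non_linearity]
    approx, all_details = y, []
    for _ in range(level):  # decompose: finest level first
        approx, details = haar2d(approx)
        all_details.append(details)
    # Threshold every detail subband; the approximation is left untouched.
    all_details = [tuple(threshold(band, ths) for band in d) for d in all_details]
    x = approx
    for details in reversed(all_details):  # recompose: coarsest level first
        x = ihaar2d(x, details)
    return x


rng = np.random.default_rng(0)
clean = np.zeros((64, 64))
clean[16:48, 16:48] = 1.0
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
denoised = wavelet_denoise(noisy, ths=0.2, level=3)
```

With ths=0 the sketch returns its input exactly, since the transform is orthonormal; this is a handy sanity check. Raising ths zeroes more detail coefficients and smooths more. The "topk" rule is not covered here.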

Parameters:
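  • level (int) – decomposition level of the wavelet transform (default: 3).

  • wv (str) – mother wavelet, following the PyWavelets naming convention (default: "db8").

  • device (str) – device on which the denoiser runs, e.g. "cpu" or "cuda" (default: "cpu").

  • non_linearity (str) – thresholding rule: "soft", "hard" or "topk" (default: "soft"). If "topk", only the top-k wavelet coefficients are kept.

  • wvdim (int) – dimension of the wavelet transform, either 2 or 3 (default: 2).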

crop_output(x, padding)[source]

Crop the output to remove the padding added by pad_input, restoring the original spatial size.

dwt(x)[source]

Applies the wavelet decomposition.

flatten_coeffs(dec)[source]

Flattens the wavelet coefficients and returns them in a single torch vector of shape (n_coeffs,).
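A minimal NumPy sketch of the flattening step, assuming (hypothetically) that a 2D decomposition is stored as [approx, (h, v, d), (h, v, d), ...], coarsest level first, in the style of PyWavelets' wavedec2 output; deepinv's internal layout may differ. For an orthonormal transform without boundary extension, the resulting vector has exactly as many entries as the image has pixels.

```python
import numpy as np


def flatten_coeffs(dec):
    """Concatenate all subbands of dec = [approx, (h, v, d), ...] into one 1D vector."""
    bands = [dec[0]] + [band for details in dec[1:] for band in details]
    return np.concatenate([np.ravel(band) for band in bands])


# Coefficients of a hypothetical 2-level decomposition of a 32x32 image:
# an 8x8 approximation, then 8x8 and 16x16 detail triples (coarsest first).
rng = np.random.default_rng(0)
dec = [rng.standard_normal((8, 8)),
       tuple(rng.standard_normal((8, 8)) for _ in range(3)),
       tuple(rng.standard_normal((16, 16)) for _ in range(3))]
vec = flatten_coeffs(dec)
print(vec.shape)  # (1024,)
```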

forward(x, ths=0.1)[source]

Run the model on a noisy image.

Parameters:
  • x (torch.Tensor) – noisy image.

  • ths (int, float, torch.Tensor) – thresholding parameter \(\gamma\). If ths is a tensor, it should have shape (1,) (same coefficient for all levels), (n_levels-1,) (one coefficient per level), or (n_levels-1, 3) (one coefficient per subband and per level). If non_linearity is "soft" or "hard", ths is used as the (soft or hard) threshold applied to the wavelet coefficients. If non_linearity is "topk", ths gives either the number of wavelet coefficients to keep (if int) or the proportion of coefficients to keep (if float).
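All three accepted tensor shapes can be reduced to one threshold per (level, subband) pair. The sketch below does this with NumPy; expand_ths is a hypothetical helper, not part of deepinv, and n_levels is taken to mean the same quantity as in the shapes listed above, so there are n_levels - 1 thresholded detail levels with 3 subbands each in 2D.

```python
import numpy as np


def expand_ths(ths, n_levels):
    """Broadcast ths to shape (n_levels - 1, 3): one threshold per level and subband."""
    ths = np.asarray(ths, dtype=float)
    if ths.size == 1:  # scalar or shape (1,): same threshold everywhere
        return np.full((n_levels - 1, 3), ths.item())
    if ths.shape == (n_levels - 1,):  # one threshold per level
        return np.repeat(ths[:, None], 3, axis=1)
    if ths.shape == (n_levels - 1, 3):  # one threshold per level and subband
        return ths
    raise ValueError(f"unsupported ths shape {ths.shape}")


print(expand_ths(0.1, n_levels=4).shape)  # (3, 3)
print(expand_ths([0.1, 0.2, 0.3], n_levels=4)[1].tolist())  # [0.2, 0.2, 0.2]
```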

hard_threshold_topk(x, ths=0.1)[source]

Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to 0.

Parameters:
  • x (torch.Tensor) – wavelet coefficients.

  • ths (float, int) – top k coefficients to keep. If float, it is interpreted as a proportion of the total number of coefficients. If int, it is interpreted as the number of coefficients to keep.
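A NumPy sketch of the top-k rule as described above, applied to an array of coefficients. Two details are assumptions rather than documented behavior: a float ths is converted to a count by rounding down (int(ths * n)), and ties at the cut-off are broken arbitrarily by np.argpartition. The function mirrors the method's name but is a standalone illustration.

```python
import numbers

import numpy as np


def hard_threshold_topk(x, ths=0.1):
    """Keep the k largest-magnitude entries of x and set the others to 0.

    An integer ths is the count k; a float ths is a proportion of x.size.
    """
    flat = np.ravel(x)
    if isinstance(ths, numbers.Integral):
        k = int(ths)
    else:
        k = int(ths * flat.size)  # rounding down is an assumption
    k = max(0, min(k, flat.size))
    out = np.zeros_like(flat)
    if k > 0:
        # indices of the k entries with the largest |x|
        idx = np.argpartition(np.abs(flat), flat.size - k)[flat.size - k:]
        out[idx] = flat[idx]
    return out.reshape(np.shape(x))


x = np.array([0.5, -3.0, 1.0, 2.0, -0.1, 0.2])
print(hard_threshold_topk(x, 2).tolist())    # [0.0, -3.0, 0.0, 2.0, 0.0, 0.0]
print(hard_threshold_topk(x, 0.5).tolist())  # [0.0, -3.0, 1.0, 2.0, 0.0, 0.0]
```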

iwt(coeffs)[source]

Applies the wavelet recomposition.

pad_input(x)[source]

Pad the input to make it compatible with the wavelet transform.
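One common requirement, assumed here, is that each spatial side be divisible by 2**level so that every decomposition level can halve it. A minimal NumPy sketch of a pad/crop pair under that assumption, padding at the bottom and right with edge replication; deepinv's actual padding mode and placement are not documented on this page and may differ. It also shows how crop_output, documented above, undoes the padding.

```python
import numpy as np


def pad_input(x, level):
    """Pad the last two axes so each is divisible by 2**level (edge mode assumed)."""
    m = 2 ** level
    h, w = x.shape[-2:]
    ph, pw = (-h) % m, (-w) % m  # rows/columns needed to reach the next multiple
    pad = [(0, 0)] * (x.ndim - 2) + [(0, ph), (0, pw)]
    return np.pad(x, pad, mode="edge"), (ph, pw)


def crop_output(x, padding):
    """Undo pad_input by dropping the padded rows and columns."""
    ph, pw = padding
    h, w = x.shape[-2:]
    return x[..., : h - ph, : w - pw]


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 30, 45))
padded, padding = pad_input(x, level=3)
print(padded.shape, padding)  # (1, 32, 48) (2, 3)
```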

prox_l0(x, ths=0.1)[source]

Hard thresholding of the wavelet coefficients.

Parameters:
  • x (torch.Tensor) – wavelet coefficients.

  • ths (float, torch.Tensor) – threshold.

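Hard thresholding at \(t\) is the exact coefficient-wise minimizer of \((z-c)^2 + t^2\,\mathbb{1}[z \neq 0]\): keeping \(c\) costs \(t^2\), zeroing it costs \(c^2\), so \(c\) survives exactly when \(|c| > t\). The NumPy sketch below checks this against a brute-force grid search; prox_l0 here is a standalone illustration rather than the method itself, and how ths is scaled relative to the \(\gamma\) of the objective is not asserted.

```python
import numpy as np


def prox_l0(c, t):
    """Hard thresholding: keep c where |c| > t, set it to 0 elsewhere."""
    return c * (np.abs(c) > t)


def l0_objective(z, c, t):
    # (z - c)^2 plus a fixed cost t^2 for every nonzero coefficient
    return (z - c) ** 2 + t ** 2 * (z != 0)


t = 0.5
grid = np.linspace(-3.0, 3.0, 60001)
for c in [-2.0, -0.7, -0.3, 0.0, 0.2, 0.49, 1.3]:
    z_star = prox_l0(c, t)
    # no grid point should beat the closed-form hard threshold
    assert l0_objective(z_star, c, t) <= l0_objective(grid, c, t).min() + 1e-12
```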
prox_l1(x, ths=0.1)[source]

Soft thresholding of the wavelet coefficients.

Parameters:
  • x (torch.Tensor) – wavelet coefficients.

  • ths (float, torch.Tensor) – threshold.

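Soft thresholding at \(t\) is the exact coefficient-wise minimizer of \(\tfrac{1}{2}(z-c)^2 + t\,|z|\), i.e. the proximal operator of \(t\|\cdot\|_1\). The sketch below verifies this numerically against a grid search; prox_l1 is a standalone NumPy illustration rather than the method itself, and the mapping between ths and the \(\gamma\) of the objective is not asserted.

```python
import numpy as np


def prox_l1(c, t):
    """Soft thresholding: shrink c toward 0 by t, clipping at 0."""
    return np.sign(c) * np.maximum(np.abs(c) - t, 0.0)


def l1_objective(z, c, t):
    return 0.5 * (z - c) ** 2 + t * np.abs(z)


t = 0.5
grid = np.linspace(-3.0, 3.0, 60001)  # spacing 1e-4
for c in [-2.0, -0.7, -0.3, 0.0, 0.2, 1.3]:
    z_star = prox_l1(c, t)
    z_grid = grid[np.argmin(l1_objective(grid, c, t))]
    # the brute-force minimizer matches the closed form up to the grid spacing
    assert abs(z_grid - z_star) <= 2e-4
```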
static psi(x, wavelet='db2', level=2, dimension=2)[source]

Returns a flattened list containing the wavelet coefficients.

Parameters:
  • x (torch.Tensor) – input image.

  • wavelet (str) – mother wavelet.

  • level (int) – decomposition level.

  • dimension (int) – dimension of the wavelet transform (either 2 or 3).

reshape_ths(ths, level)[source]

Reshape the thresholding parameter into the appropriate format, i.e. either:
  • a list of 3 elements, or

  • a tensor of 3 elements.

The approximation coefficients are not thresholded, so they need no thresholding parameter; ths therefore has shape (n_levels-1, 3).

threshold_3D(coeffs, ths)[source]

Thresholds coefficients of the 3D wavelet transform.

threshold_ND(coeffs, ths)[source]

Apply thresholding to the wavelet coefficients of arbitrary dimension.
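In \(d\) dimensions each decomposition level has \(2^d - 1\) detail subbands (3 in 2D, 7 in 3D), which is why a per-subband threshold has 3 columns in 2D. A minimal NumPy sketch of dimension-agnostic thresholding, assuming coefficients are stored as [approx, {key: band}, ...] with keys modeled on PyWavelets' wavedecn output ('a'/'d' per axis); the helper names are hypothetical, and deepinv's internal layout may differ.

```python
import itertools

import numpy as np


def detail_keys(dim):
    """Detail-subband keys in the PyWavelets wavedecn style ('a'/'d' per axis)."""
    return ["".join(k) for k in itertools.product("ad", repeat=dim) if "d" in k]


def threshold_nd(coeffs, ths, func):
    """Apply func(band, t) to every detail subband of coeffs = [approx, {key: band}, ...].

    ths[j][i] is the threshold for detail level j and the i-th key in sorted order.
    """
    out = [coeffs[0]]  # approximation coefficients are not thresholded
    for j, level in enumerate(coeffs[1:]):
        out.append({key: func(level[key], ths[j][i])
                    for i, key in enumerate(sorted(level))})
    return out


def soft(c, t):
    return np.sign(c) * np.maximum(np.abs(c) - t, 0.0)


print(len(detail_keys(2)), len(detail_keys(3)))  # 3 7

# hypothetical 2-level 3D decomposition of an 8x8x8 volume (coarsest level first)
rng = np.random.default_rng(0)
coeffs = [rng.standard_normal((2, 2, 2)),
          {k: rng.standard_normal((2, 2, 2)) for k in detail_keys(3)},
          {k: rng.standard_normal((4, 4, 4)) for k in detail_keys(3)}]
ths = np.full((2, 7), 0.5)  # one threshold per level and subband
out = threshold_nd(coeffs, ths, soft)
```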

thresold_2D(coeffs, ths)[source]

Thresholds coefficients of the 2D wavelet transform.

thresold_func(x, ths)[source]

Apply thresholding to the wavelet coefficients.

Examples using WaveletDenoiser:

Saving and loading models

3D wavelet denoising

Unfolded Chambolle-Pock for constrained image inpainting
