
class deepinv.models.WaveletDenoiser(level=3, wv='db8', device='cpu', non_linearity='soft', wvdim=2)

Bases: Denoiser

Orthogonal wavelet denoising with the \(\ell_1\) norm (\(\ell_0\) variants are also available).

This denoiser is defined as the solution to the optimization problem:

\[\underset{x}{\arg\min} \; \tfrac{1}{2}\|x-y\|^2 + \gamma \|\Psi x\|_n\]

where \(\Psi\) is an orthonormal wavelet transform, \(\gamma>0\) is a hyperparameter, and \(\|\cdot\|_n\) is either the \(\ell_1\) norm (non_linearity="soft") or the \(\ell_0\) norm (non_linearity="hard"). A variant of the \(\ell_0\) norm is also available (non_linearity="topk"), in which the \(k\) largest coefficients in each wavelet subband are kept and the others are set to zero.

The solution is available in closed form (a thresholding of the wavelet coefficients of \(y\)), so the denoiser is cheap to compute.
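A short derivation of the closed form, given as a sketch. It uses the data term \(\tfrac12\|x-y\|^2\) (with the unscaled \(\|x-y\|^2\), replace \(\gamma\) by \(\gamma/2\) throughout), and treats the approximation coefficients, which are not thresholded, as carrying zero penalty. Since \(\Psi\) is orthonormal, \(\|x-y\|^2 = \|\Psi x-\Psi y\|^2\), so with \(z=\Psi x\) and \(c=\Psi y\) the problem separates over coefficients:

\[\hat z_i = \underset{z_i}{\arg\min}\; \tfrac12 (z_i-c_i)^2 + \gamma |z_i| = \operatorname{sign}(c_i)\max(|c_i|-\gamma,\,0), \qquad \hat x = \Psi^\top \hat z.\]

For the \(\ell_0\) penalty the same argument gives hard thresholding: \(\hat z_i = c_i\) if \(|c_i| > \sqrt{2\gamma}\), and \(\hat z_i = 0\) otherwise. In both cases the cost is one forward transform, one elementwise operation and one inverse transform.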


This model requires the PyTorch Wavelet Toolbox (ptwt) to be installed. It can be installed with pip install ptwt.
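A minimal sketch of the closed-form denoiser \(\hat x = \Psi^\top S_\gamma(\Psi y)\), assuming a one-level orthonormal Haar transform on a 1D signal and plain NumPy instead of ptwt. The function names are hypothetical and this is not the deepinv implementation, which works on 2D/3D tensors with the wavelet chosen by wv. As in the class, the approximation coefficients are left untouched.

```python
import numpy as np

# Hypothetical 1D Haar sketch; not deepinv's implementation.


def haar_forward(y):
    """One-level orthonormal Haar transform of an even-length 1D signal."""
    even, odd = y[0::2], y[1::2]
    approx = (even + odd) / np.sqrt(2.0)  # low-pass (approximation) coefficients
    detail = (even - odd) / np.sqrt(2.0)  # high-pass (detail) coefficients
    return approx, detail


def haar_inverse(approx, detail):
    """Inverse of haar_forward; for an orthonormal transform this is its transpose."""
    x = np.empty(2 * approx.size)
    x[0::2] = (approx + detail) / np.sqrt(2.0)
    x[1::2] = (approx - detail) / np.sqrt(2.0)
    return x


def soft_threshold(c, gamma):
    """Elementwise proximal operator of gamma * |.|."""
    return np.sign(c) * np.maximum(np.abs(c) - gamma, 0.0)


def haar_soft_denoise(y, gamma):
    """x_hat = Psi^T S_gamma(Psi y), thresholding only the detail coefficients."""
    approx, detail = haar_forward(np.asarray(y, dtype=float))
    return haar_inverse(approx, soft_threshold(detail, gamma))
```

For example, haar_soft_denoise([4.0, 4.2, 1.0, 0.8], 0.5) returns the pairwise averages [4.1, 4.1, 0.9, 0.9], because both detail coefficients (about ±0.14) fall below the threshold.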

  • level (int) – decomposition level of the wavelet transform (default: 3)

  • wv (str) – mother wavelet (follows the PyWavelets convention) (default: "db8")

  • device (str) – device on which the model runs, e.g. "cpu" or "cuda" (default: "cpu")

  • non_linearity (str) – "soft", "hard" or "topk" thresholding (default: "soft"). If "topk", only the top-k wavelet coefficients are kept.

  • wvdim (int) – dimension of the wavelet transform, either 2 or 3 (default: 2)

crop_output(x, padding)

Crop the output to remove the padding (given by padding) that was added to make the input compatible with the wavelet transform.


Applies the wavelet decomposition.


Flattens the wavelet coefficients and returns them in a single torch vector of shape (n_coeffs,).

forward(x, ths=0.1, **kwargs)

Run the model on a noisy image.

  • x (torch.Tensor) – noisy image.

  • ths (int, float, torch.Tensor) – thresholding parameter \(\gamma\). If ths is a tensor, it should be of shape (1,) (the same coefficient for all levels), (n_levels-1,) (one coefficient per level), or (n_levels-1, 3) (one coefficient per subband and per level). If non_linearity is "soft" or "hard", ths is the (soft or hard) threshold applied to the wavelet coefficients. If non_linearity is "topk", ths gives either the number of wavelet coefficients kept (if int) or the proportion of coefficients kept (if float).

hard_threshold_topk(x, ths=0.1)

Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to 0.

  • x (torch.Tensor) – wavelet coefficients.

  • ths (float, int) – how many of the largest coefficients to keep. If float, it is interpreted as a proportion of the total number of coefficients. If int, it is interpreted as the number of coefficients to keep.
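A sketch of the top-k rule on a single NumPy array, under stated assumptions: a float ths is converted to a count by rounding down, and ties are broken by argsort order. The helper name is hypothetical; per the class description, deepinv applies the rule per wavelet subband rather than to one flat array.

```python
import numpy as np


def hard_threshold_topk_sketch(x, ths=0.1):
    """Keep the k largest-magnitude entries of x and set the others to 0.

    Hypothetical helper. An int ths is the number of entries kept; a float ths
    is a proportion of x.size, rounded down (the rounding is an assumption).
    """
    x = np.asarray(x, dtype=float)
    if isinstance(ths, (int, np.integer)):
        k = int(ths)
    else:
        k = int(float(ths) * x.size)
    k = max(0, min(k, x.size))  # clamp to a valid count

    flat = x.ravel()
    out = np.zeros(flat.size)
    if k > 0:
        keep = np.argsort(np.abs(flat))[-k:]  # indices of the k largest magnitudes
        out[keep] = flat[keep]
    return out.reshape(x.shape)
```

With x = [0.1, -3.0, 2.0, 0.5], both ths=2 and ths=0.5 keep the two entries -3.0 and 2.0.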


Applies the wavelet recomposition.


Pad the input to make it compatible with the wavelet transform.
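A sketch of a pad/crop round trip. It assumes (this is not stated above) that each spatial size is padded at the bottom/right by edge replication up to a multiple of 2**level; deepinv's actual padding rule may differ. Both helper names are hypothetical.

```python
import numpy as np


def pad_input_sketch(x, level):
    """Pad the last two axes of x (edge replication, bottom/right) so that
    each is divisible by 2**level. Returns the padded array and the padding."""
    factor = 2 ** level
    h, w = x.shape[-2:]
    pad_h = (-h) % factor  # rows to add at the bottom
    pad_w = (-w) % factor  # columns to add on the right
    pad_width = [(0, 0)] * (x.ndim - 2) + [(0, pad_h), (0, pad_w)]
    return np.pad(x, pad_width, mode="edge"), (pad_h, pad_w)


def crop_output_sketch(x, padding):
    """Undo pad_input_sketch by removing the padded rows and columns."""
    pad_h, pad_w = padding
    h, w = x.shape[-2:]
    return x[..., : h - pad_h, : w - pad_w]
```

Because padding is only added at the end of each axis, cropping with the returned padding recovers the original array exactly.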

prox_l0(x, ths=0.1)

Hard thresholding of the wavelet coefficients.

prox_l1(x, ths=0.1)

Soft thresholding of the wavelet coefficients.
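Both proximal operators act elementwise on the coefficients. A NumPy sketch with hypothetical helper names; whether a coefficient whose magnitude equals ths exactly survives hard thresholding is an assumption here.

```python
import numpy as np


def prox_l1_sketch(x, ths):
    """Soft thresholding: shrink every coefficient towards 0 by ths."""
    return np.sign(x) * np.maximum(np.abs(x) - ths, 0.0)


def prox_l0_sketch(x, ths):
    """Hard thresholding: zero coefficients with |x| <= ths, keep the rest as-is."""
    return np.where(np.abs(x) > ths, x, 0.0)
```

Soft thresholding biases the surviving coefficients towards zero by ths, while hard thresholding keeps them unchanged.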

static psi(x, wavelet='db2', level=2, dimension=2)

Returns a flattened list containing the wavelet coefficients.

  • x (torch.Tensor) – input image.

  • wavelet (str) – mother wavelet.

  • level (int) – decomposition level.

  • dimension (int) – dimension of the wavelet transform (either 2 or 3).
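To illustrate what a flattened coefficient vector looks like, here is a sketch restricted to the 1D Haar wavelet (an assumption: psi itself supports other wavelets and 2D/3D inputs through ptwt). It orders the coefficients coarsest first, as PyWavelets' wavedec does, and concatenates them into one vector of shape (n_coeffs,). The function names are hypothetical.

```python
import numpy as np


def haar_wavedec_1d(x, level):
    """Multi-level orthonormal Haar decomposition of a 1D signal whose length
    is divisible by 2**level. Returns [approx, detail_level, ..., detail_1]."""
    details = []
    approx = np.asarray(x, dtype=float)
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))  # finer levels are appended first
        approx = (even + odd) / np.sqrt(2.0)
    return [approx] + details[::-1]  # coarsest detail level right after approx


def psi_sketch(x, level=2):
    """Flatten all coefficients into a single vector of shape (n_coeffs,)."""
    return np.concatenate([c.ravel() for c in haar_wavedec_1d(x, level)])
```

A constant signal has no detail content: psi_sketch(np.ones(8), level=2) gives two approximation coefficients equal to 2 followed by six zeros.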

reshape_ths(ths, level)
Reshape the thresholding parameter into the appropriate format, i.e. either:
  • a list of 3 elements, or

  • a tensor of 3 elements.

Since the approximation coefficients are not thresholded, no thresholding parameter is needed for them; ths therefore has shape (n_levels-1, 3).
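A sketch of how the accepted ths shapes can be broadcast to one value per detail level and subband. It assumes that n_levels-1 equals the decomposition level (n_levels counting the approximation plus the detail levels) and a 2D transform with 3 subbands per level. expand_ths is a hypothetical name, not part of deepinv.

```python
import numpy as np


def expand_ths(ths, level):
    """Broadcast ths to one value per (detail level, subband): shape (level, 3).

    Accepted inputs, mirroring the shapes listed for forward:
      scalar or shape (1,)  -> same value everywhere
      shape (level,)        -> one value per level, shared by its 3 subbands
      shape (level, 3)      -> one value per level and per subband
    """
    ths = np.asarray(ths, dtype=float)
    if ths.size == 1:
        return np.full((level, 3), ths.item())
    if ths.shape == (level,):
        return np.repeat(ths[:, None], 3, axis=1)
    if ths.shape == (level, 3):
        return ths.copy()
    raise ValueError(f"unsupported ths shape {ths.shape} for level={level}")
```

Normalizing once to a (level, 3) array lets the thresholding loop index ths[i, j] without branching on the input format.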

threshold_3D(coeffs, ths)

Thresholds coefficients of the 3D wavelet transform.

threshold_ND(coeffs, ths)

Apply thresholding to the wavelet coefficients of arbitrary dimension.

thresold_2D(coeffs, ths)

Thresholds coefficients of the 2D wavelet transform.
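A sketch of per-subband thresholding for a 2D decomposition. It assumes a PyWavelets-style coefficient list [approx, (h, v, d) coarsest, ..., (h, v, d) finest], soft thresholding, and that row i of ths applies to the i-th detail level counted from the coarsest; the approximation is returned unchanged. The function names are hypothetical.

```python
import numpy as np


def soft(x, t):
    """Elementwise soft thresholding."""
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)


def threshold_2d_sketch(coeffs, ths):
    """coeffs: [approx, (h, v, d)_coarsest, ..., (h, v, d)_finest].
    ths: array of shape (len(coeffs) - 1, 3); row i is used for coeffs[i + 1]
    (this ordering is an assumption). The approximation is not thresholded."""
    ths = np.asarray(ths, dtype=float)
    out = [coeffs[0]]
    for i, subbands in enumerate(coeffs[1:]):
        # threshold each of the 3 detail subbands with its own value
        out.append(tuple(soft(band, ths[i, j]) for j, band in enumerate(subbands)))
    return out
```

Keeping a separate threshold per level and subband allows, for example, stronger shrinkage on the finest (noisiest) detail levels.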

thresold_func(x, ths)

Apply thresholding to the wavelet coefficients.

Examples using WaveletDenoiser:

Saving and loading models

3D wavelet denoising

Unfolded Chambolle-Pock for constrained image inpainting