WaveletDenoiser
- class deepinv.models.WaveletDenoiser(level=3, wv='db8', device='cpu', non_linearity='soft', wvdim=2)
Bases: Denoiser
Orthogonal wavelet denoising with the \(\ell_1\) norm.
This denoiser is defined as the solution to the optimization problem:
\[\underset{x}{\arg\min} \; \|x-y\|^2 + \gamma \|\Psi x\|_n\]
where \(\Psi\) is an orthonormal wavelet transform, \(\gamma>0\) is a hyperparameter, and \(\|\cdot\|_n\) is either the \(\ell_1\) norm (non_linearity="soft") or the \(\ell_0\) norm (non_linearity="hard"). A variant of the \(\ell_0\) norm is also available (non_linearity="topk"), where the thresholding is done by keeping the \(k\) largest coefficients in each wavelet subband and setting the others to zero. The solution is available in closed form, so the denoiser is cheap to compute.
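Because \(\Psi\) is orthonormal, \(\|x-y\| = \|\Psi x - \Psi y\|\), so the problem decouples over wavelet coefficients: the minimizer is obtained by thresholding \(\Psi y\) and mapping the result back with \(\Psi^\top\). The following is a minimal NumPy sketch of that transform, threshold, inverse-transform pipeline, assuming a single-level 1D Haar transform in place of the multi-level ptwt transform the class uses. The helper names `haar_1d`, `ihaar_1d`, `soft` and `denoise_1d` are hypothetical, and, as for this denoiser, only the detail coefficients are thresholded.

```python
import numpy as np


def haar_1d(x):
    """One level of the orthonormal Haar transform (len(x) must be even)."""
    even, odd = x[0::2], x[1::2]
    approx = (even + odd) / np.sqrt(2.0)
    detail = (even - odd) / np.sqrt(2.0)
    return approx, detail


def ihaar_1d(approx, detail):
    """Inverse of haar_1d; for an orthonormal transform this is Psi transposed."""
    x = np.empty(2 * approx.size)
    x[0::2] = (approx + detail) / np.sqrt(2.0)
    x[1::2] = (approx - detail) / np.sqrt(2.0)
    return x


def soft(u, ths):
    """Soft thresholding: shrink toward zero by ths, zero out |u| <= ths."""
    return np.sign(u) * np.maximum(np.abs(u) - ths, 0.0)


def denoise_1d(y, ths):
    """Transform, threshold the detail coefficients only, transform back."""
    approx, detail = haar_1d(y)
    return ihaar_1d(approx, soft(detail, ths))


# A piecewise-constant signal plus small perturbations.
clean = np.array([4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0])
noisy = clean + np.array([0.1, -0.1, 0.05, -0.05, 0.1, -0.1, -0.05, 0.05])
print(denoise_1d(noisy, ths=0.0))  # ths=0 gives back the noisy input
print(denoise_1d(noisy, ths=0.5))  # small detail coefficients are removed
```

The whole denoiser is two linear transforms and one elementwise nonlinearity, which is why no iterative solver is needed.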
Warning
This model requires the PyTorch Wavelet Toolbox (ptwt) to be installed. It can be installed with pip install ptwt.
- Parameters:
level (int) – decomposition level of the wavelet transform (default: 3)
wv (str) – mother wavelet (follows the PyWavelets convention) (default: "db8")
device (str) – "cpu" or a GPU device such as "cuda" (default: "cpu")
non_linearity (str) – "soft", "hard" or "topk" thresholding (default: "soft"). If "topk", only the top-k wavelet coefficients are kept.
wvdim (int) – dimension of the wavelet transform, either 2 or 3 (default: 2)
- flatten_coeffs(dec)
Flattens the wavelet coefficients and returns them in a single torch vector of shape (n_coeffs,).
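To illustrate what flattening means, here is a minimal NumPy sketch assuming a PyWavelets-style decomposition layout `[cA_n, (d1, d2, d3)_n, ..., (d1, d2, d3)_1]`, i.e. one approximation array followed by one tuple of three detail arrays per level. The function name `flatten_coeffs_np` is hypothetical; the actual method works on torch tensors, and its coefficient ordering is not stated here.

```python
import numpy as np


def flatten_coeffs_np(dec):
    """Concatenate all coefficients into one 1D vector of shape (n_coeffs,).

    Assumes dec = [cA, (d1, d2, d3), (d1, d2, d3), ...]: an approximation
    array followed by one tuple of detail arrays per level.
    """
    parts = [np.ravel(dec[0])]
    for details in dec[1:]:
        parts.extend(np.ravel(d) for d in details)
    return np.concatenate(parts)


# Coefficient shapes of a 2-level Haar decomposition of an 8x8 image
# (Haar needs no boundary padding, so the sizes add up to 64).
dec = [
    np.ones((2, 2)),
    tuple(np.full((2, 2), k) for k in (1.0, 2.0, 3.0)),
    tuple(np.full((4, 4), k) for k in (4.0, 5.0, 6.0)),
]
vec = flatten_coeffs_np(dec)
print(vec.shape)  # (64,)
```

A flat vector is convenient for operations that act on all coefficients at once, such as counting or ranking them.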
- forward(x, ths=0.1, **kwargs)
Run the model on a noisy image.
- Parameters:
x (torch.Tensor) – noisy image.
ths (int, float, torch.Tensor) – thresholding parameter \(\gamma\). If ths is a tensor, it should be of shape (1,) (same coefficient for all levels), (n_levels-1,) (one coefficient per level), or (n_levels-1, 3) (one coefficient per subband and per level). If non_linearity equals "soft" or "hard", ths serves as a (soft or hard) thresholding parameter for the wavelet coefficients. If non_linearity equals "topk", ths indicates either the number of wavelet coefficients that are kept (if int) or the proportion of coefficients that are kept (if float).
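The three accepted shapes for ths (with "soft" or "hard" thresholding) can all be expanded to one value per detail level and subband. A minimal NumPy sketch of that expansion, assuming the target shape is (n_levels-1, 3) as documented; `broadcast_ths` is a hypothetical helper, not the library's own reshaping code.

```python
import numpy as np


def broadcast_ths(ths, n_levels):
    """Expand a threshold into one value per (detail level, subband).

    Accepts a scalar, shape (1,), shape (n_levels - 1,) or shape
    (n_levels - 1, 3) and returns an array of shape (n_levels - 1, 3).
    """
    ths = np.asarray(ths, dtype=float)
    n_detail = n_levels - 1
    if ths.ndim == 0 or ths.shape == (1,):
        # Same threshold for every level and subband.
        return np.full((n_detail, 3), float(ths.reshape(-1)[0]))
    if ths.shape == (n_detail,):
        # One threshold per level, repeated over the 3 subbands.
        return np.repeat(ths[:, None], 3, axis=1)
    if ths.shape == (n_detail, 3):
        # Already one threshold per level and subband.
        return ths.copy()
    raise ValueError(f"unsupported ths shape {ths.shape}")


print(broadcast_ths(0.1, n_levels=4).shape)  # (3, 3)
print(broadcast_ths([0.1, 0.2, 0.3], n_levels=4)[1])  # [0.2 0.2 0.2]
```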
- hard_threshold_topk(x, ths=0.1)
Hard thresholding of the wavelet coefficients that keeps only the top-k coefficients and sets the others to 0.
- Parameters:
x (torch.Tensor) – wavelet coefficients.
ths (float, int) – number of top coefficients to keep. If float, it is interpreted as a proportion of the total number of coefficients; if int, it is interpreted as the number of coefficients to keep.
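A minimal NumPy sketch of top-k hard thresholding on a single array (for instance one subband), ranking entries by magnitude. The name `hard_threshold_topk_np` is hypothetical, and converting a float proportion with `int(ths * size)` (rounding down) is an assumption; the source does not state the rounding rule.

```python
import numpy as np


def hard_threshold_topk_np(x, ths=0.1):
    """Keep the k largest-magnitude entries of x and set the others to 0.

    An int ths is the number k of entries to keep; a float ths is the
    proportion of entries to keep, converted (by assumption) as
    k = int(ths * x.size).
    """
    flat = np.ravel(x)
    if isinstance(ths, (int, np.integer)):
        k = int(ths)
    else:
        k = int(ths * flat.size)
    k = max(0, min(k, flat.size))
    out = np.zeros_like(flat)
    if k > 0:
        # argpartition places the indices of the k largest magnitudes last.
        idx = np.argpartition(np.abs(flat), flat.size - k)[flat.size - k:]
        out[idx] = flat[idx]
    return out.reshape(np.shape(x))


x = np.array([0.5, -3.0, 1.0, 2.0, -0.1])
print(hard_threshold_topk_np(x, 2))    # keeps -3.0 and 2.0
print(hard_threshold_topk_np(x, 0.4))  # 40% of 5 entries: same result
```

Unlike a fixed threshold, this keeps a fixed budget of coefficients regardless of their absolute scale.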
- prox_l0(x, ths=0.1)
Hard thresholding of the wavelet coefficients.
- Parameters:
x (torch.Tensor) – wavelet coefficients.
ths (float, torch.Tensor) – threshold.
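Hard thresholding is the proximal operator of a scaled \(\ell_0\) penalty: per coefficient, minimizing \(\tfrac12 (v-u)^2 + \lambda [v \neq 0]\) only compares \(v=0\) with \(v=u\), so \(u\) is kept exactly when \(|u| > \sqrt{2\lambda}\). A minimal NumPy sketch checking this for threshold \(t\) and \(\lambda = t^2/2\); the \(\tfrac12\) scaling of the data term and the helper names `hard_threshold` and `l0_objective` are assumptions for illustration, not the library's code.

```python
import numpy as np


def hard_threshold(x, ths):
    """Keep entries with |x| > ths unchanged, set the rest to 0."""
    return np.where(np.abs(x) > ths, x, 0.0)


def l0_objective(v, u, lam):
    """Per-coefficient cost 0.5 * (v - u)**2 + lam * [v != 0]."""
    return 0.5 * (v - u) ** 2 + lam * (v != 0)


ths = 1.0
lam = ths**2 / 2  # hard thresholding at ths <-> prox of (ths**2 / 2) * ||.||_0
for u in (-2.0, -0.7, 0.3, 1.5):
    # For v != 0 the best choice is v = u, so only v = 0 and v = u compete.
    best = min((0.0, u), key=lambda v: l0_objective(v, u, lam))
    print(u, float(hard_threshold(u, ths)), best)
```

Kept coefficients are left unchanged, which is the main difference from soft thresholding.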
- prox_l1(x, ths=0.1)
Soft thresholding of the wavelet coefficients.
- Parameters:
x (torch.Tensor) – wavelet coefficients.
ths (float, torch.Tensor) – threshold.
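Soft thresholding is the proximal operator of the \(\ell_1\) norm: for each coefficient \(u\), the minimizer of \(\tfrac12 (v-u)^2 + t|v|\) is \(\operatorname{sign}(u)\max(|u|-t, 0)\). A minimal NumPy sketch that compares this closed form with a brute-force search over a fine grid; `soft_threshold` is a hypothetical name and the \(\tfrac12\) scaling of the data term is an assumption.

```python
import numpy as np


def soft_threshold(x, ths):
    """Shrink every entry toward 0 by ths; entries with |x| <= ths become 0."""
    return np.sign(x) * np.maximum(np.abs(x) - ths, 0.0)


# Compare the closed form with a brute-force minimization of
# 0.5 * (v - u)**2 + ths * |v| over a grid with spacing 1e-4.
ths = 0.5
grid = np.linspace(-3.0, 3.0, 60001)
for u in (-2.0, -0.3, 0.0, 0.4, 1.2):
    cost = 0.5 * (grid - u) ** 2 + ths * np.abs(grid)
    v_grid = grid[np.argmin(cost)]
    print(u, float(soft_threshold(u, ths)), v_grid)
```

The two columns agree up to the grid spacing; note that surviving coefficients are shrunk by ths, not kept as-is.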
- static psi(x, wavelet='db2', level=2, dimension=2)
Returns a flattened list containing the wavelet coefficients.
- Parameters:
x (torch.Tensor) – input image.
wavelet (str) – mother wavelet.
level (int) – decomposition level.
dimension (int) – dimension of the wavelet transform (either 2 or 3).
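A minimal NumPy sketch of a multi-level 2D decomposition in the spirit of psi, assuming the orthonormal Haar wavelet instead of db2 and input sizes divisible by \(2^{\text{level}}\); `haar2d_level` and `psi_haar` are hypothetical names. Flattening the coefficients shows that the transform preserves energy (Parseval), which is what makes \(\Psi\) orthonormal.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar2d_level(a):
    """One level of the orthonormal 2D Haar transform (even height and width)."""
    lo = (a[:, 0::2] + a[:, 1::2]) / SQRT2  # combine neighbouring columns
    hi = (a[:, 0::2] - a[:, 1::2]) / SQRT2
    approx = (lo[0::2] + lo[1::2]) / SQRT2  # then neighbouring rows
    details = (
        (lo[0::2] - lo[1::2]) / SQRT2,
        (hi[0::2] + hi[1::2]) / SQRT2,
        (hi[0::2] - hi[1::2]) / SQRT2,
    )
    return approx, details


def psi_haar(x, level=2):
    """Return [cA_level, details_level, ..., details_1], coarsest first."""
    coeffs = []
    approx = x
    for _ in range(level):
        approx, details = haar2d_level(approx)
        coeffs.insert(0, details)
    return [approx] + coeffs


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
coeffs = psi_haar(x, level=2)
flat = np.concatenate([coeffs[0].ravel()] + [d.ravel() for ds in coeffs[1:] for d in ds])
print(flat.size, np.allclose(np.sum(flat**2), np.sum(x**2)))  # 64 True
```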
- reshape_ths(ths, level)
Reshape the thresholding parameter into the appropriate format, i.e. either:
a list of 3 elements, or
a tensor of 3 elements.
Since the approximation coefficients are not thresholded, no thresholding parameter is needed for them; ths therefore has shape (n_levels-1, 3).
Examples using WaveletDenoiser
Unfolded Chambolle-Pock for constrained image inpainting